In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7,"
import numpy as np
from PIL import Image

import torch
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import torchmetrics
from tqdm import tqdm
import matplotlib.pyplot as plt


# torch.random.manual_seed(42)
# np.random.seed(42)

Extension horovod.torch has not been built: /home/fzj/.local/lib/python3.8/site-packages/horovod/torch/mpi_lib_v2.cpython-38-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.


In [2]:
class DiscriminatorDataset(Dataset):
    def __init__(self, image_list, label_list, transform=None):
        self.image_list = image_list
        self.label_list = label_list
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image = Image.open(self.image_list[idx])
        label = self.label_list[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [3]:
for dname in ['bezier224_5_0.2_0.05_1d1','mosaic_2_112']:
    for run in [2, 5]: # range(1, 6):
        for N_sample in [40]: # [10, 20, 30, 40]:
            train_image_dir = f'data/LUAD-HistoSeg/train'
            no_stain_norm_synthesize_image_dir = f"data/LUAD-HistoSeg/limit_N/{dname}_N_{N_sample}_run{run}/img"
            log_dir = f'discriminate_logs_limit_N/luad_{dname}_N_{N_sample}_run{run}.txt'

            train_image_list = sorted([os.path.join(train_image_dir, i) for i in os.listdir(train_image_dir) if ".png" in i])
            real_image_list = []
            for image_name in train_image_list:
                label_str = image_name.split(']')[0].split('[')[-1]
                label_str = label_str.replace(' ', ',')
                label = eval(label_str)
                if sum(label) > 1:
                    real_image_list.append(image_name)

            synthesize_image_list = sorted([os.path.join(no_stain_norm_synthesize_image_dir, i) for i in os.listdir(no_stain_norm_synthesize_image_dir) if ".png" in i])

            np.random.shuffle(real_image_list)
            np.random.shuffle(synthesize_image_list)

            train_real_image_list = real_image_list[:int(len(real_image_list)*0.8)]
            val_real_image_list = real_image_list[int(len(real_image_list)*0.8):]

            train_synthesize_image_list = synthesize_image_list[:int(len(synthesize_image_list)*0.8)]
            val_synthesize_image_list = synthesize_image_list[int(len(synthesize_image_list)*0.8):]

            train_image_list = train_real_image_list[:5000] + train_synthesize_image_list[:5000]
            train_label_list = [1] * 5000 + [0] * 5000

            val_image_list = val_real_image_list + val_synthesize_image_list
            val_label_list = [1] * len(val_real_image_list) + [0] * len(val_synthesize_image_list)

            train_dataset = DiscriminatorDataset(train_image_list, train_label_list, transform=transforms.Compose([
                transforms.Resize((224,224)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
            ]))

            val_dataset = DiscriminatorDataset(val_image_list, val_label_list, transform=transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
            ]))

            train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
            val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)

            model = timm.create_model('resnet18', pretrained=False, num_classes=2)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
            # criterion = nn.BCEWithLogitsLoss()
            criterion = nn.CrossEntropyLoss()
            metrics = torchmetrics.MetricCollection({ 
                'acc': torchmetrics.Accuracy(num_classes=2, average='micro'), 
                'prec': torchmetrics.Precision(num_classes=2, average='macro'), 
                'rec': torchmetrics.Recall(num_classes=2, average='macro'),
                'f1': torchmetrics.F1Score(num_classes=2, average='macro')
            }).cuda()

            neg_metrics = torchmetrics.MetricCollection({ 
                'acc': torchmetrics.Accuracy(task='binary'), 
                'prec': torchmetrics.Precision(task='binary'), 
                'rec': torchmetrics.Recall(task='binary'),
                'f1': torchmetrics.F1Score(task='binary')
            }).cuda()

            model = model.cuda()

            EPOCH_NUM = 5
            best_acc = 0
            train_loss_list = []
            val_metrics_list = []
            for epoch in range(EPOCH_NUM):
                print(f"Epoch {epoch+1}/{EPOCH_NUM}")
                model.train()
                loss_list = []
                for batch in tqdm(train_dataloader, desc='train'):
                    image, label = batch
                    image = image.cuda()
                    label = label.cuda()
                    optimizer.zero_grad()
                    output = model(image) # 1 real, 0 fake
                    loss = criterion(output, label)
                    loss_list.append(loss.item())
                    loss.backward()
                    optimizer.step()
                print(f'train_loss: {np.mean(loss_list)}')
                train_loss_list.append(np.mean(loss_list))

                model.eval()
                for batch in tqdm(val_dataloader, desc='val'):
                    image, label = batch
                    image = image.cuda()
                    label = label.cuda()
                    with torch.no_grad():
                        output = model(image)
                    metrics(output, label)
                    neg_metrics(preds=output[:,0], target=1-label)
                
                metrics_dict = metrics.compute()
                neg_metrics_dict = neg_metrics.compute()
                metrics.reset()
                neg_metrics.reset()

                val_metrics_list.append(metrics_dict)
                if metrics_dict['acc'] > best_acc:
                    best_acc = metrics_dict['acc']
                    torch.save(model.state_dict(), f"weights_limit_N/dis_luad_{dname}_N_{N_sample}_r18_e{EPOCH_NUM}_run{run}.pth")
                print({k: v.item() for k, v in metrics_dict.items()})
                print({k: v.item() for k, v in neg_metrics_dict.items()})

                with open(log_dir, 'a') as f:
                    f.write(f"Epoch {epoch+1}/{EPOCH_NUM}\n")
                    f.write(f'train_loss: {np.mean(loss_list)}\n')
                    f.write(str({k: v.item() for k, v in metrics_dict.items()}) + '\n')
                    f.write(str({k: v.item() for k, v in neg_metrics_dict.items()}) + '\n')
                
            plt.plot(train_loss_list, label='train_loss')
            plt.plot([i['acc'].item() for i in val_metrics_list], label='acc')
            plt.plot([i['prec'].item() for i in val_metrics_list], label='prec')
            plt.plot([i['rec'].item() for i in val_metrics_list], label='rec')
            plt.plot([i['f1'].item() for i in val_metrics_list], label='f1')
            plt.legend()
            plt.savefig(f'discriminate_logs_limit_N/dis_luad_{dname}_N_{N_sample}_r18_e{EPOCH_NUM}_run{run}.png')
            plt.clf()
            plt.close()


Epoch 1/5


train: 100%|██████████| 40/40 [00:27<00:00,  1.48it/s]


train_loss: 0.4667718060314655


val: 100%|██████████| 18/18 [00:12<00:00,  1.47it/s]


{'acc': 0.726963996887207, 'f1': 0.6946444511413574, 'prec': 0.7906365394592285, 'rec': 0.7025112509727478}
{'acc': 0.685080349445343, 'f1': 0.48652637004852295, 'prec': 0.929478108882904, 'rec': 0.3294999897480011}
Epoch 2/5


train: 100%|██████████| 40/40 [00:11<00:00,  3.35it/s]


train_loss: 0.33256066106259824


val: 100%|██████████| 18/18 [00:06<00:00,  2.88it/s]


{'acc': 0.7247000336647034, 'f1': 0.7222481966018677, 'prec': 0.758641242980957, 'rec': 0.7401671409606934}
{'acc': 0.764319658279419, 'f1': 0.7634628415107727, 'prec': 0.6997084617614746, 'rec': 0.8399999737739563}
Epoch 3/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.24it/s]


train_loss: 0.24787562638521193


val: 100%|██████████| 18/18 [00:06<00:00,  2.93it/s]


{'acc': 0.8582748174667358, 'f1': 0.858126163482666, 'prec': 0.8752943873405457, 'rec': 0.8690772652626038}
{'acc': 0.9058184027671814, 'f1': 0.8983381986618042, 'prec': 0.8785851001739502, 'rec': 0.9190000295639038}
Epoch 4/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.14it/s]


train_loss: 0.1626947395503521


val: 100%|██████████| 18/18 [00:06<00:00,  2.99it/s]


{'acc': 0.9234774708747864, 'f1': 0.9218626022338867, 'prec': 0.9304136037826538, 'rec': 0.9180879592895508}
{'acc': 0.8684627413749695, 'f1': 0.8321294784545898, 'prec': 0.9856262803077698, 'rec': 0.7200000286102295}
Epoch 5/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.24it/s]


train_loss: 0.13220207048580052


val: 100%|██████████| 18/18 [00:06<00:00,  2.99it/s]


{'acc': 0.9556260108947754, 'f1': 0.9554532766342163, 'prec': 0.9546167850494385, 'rec': 0.9582893252372742}
{'acc': 0.9576635956764221, 'f1': 0.9525741934776306, 'prec': 0.9665465950965881, 'rec': 0.9390000104904175}
Epoch 1/5


train: 100%|██████████| 40/40 [00:35<00:00,  1.13it/s]


train_loss: 0.46937166824936866


val: 100%|██████████| 18/18 [00:16<00:00,  1.08it/s]


{'acc': 0.7500566244125366, 'f1': 0.73487389087677, 'prec': 0.7720823884010315, 'rec': 0.7340066432952881}
{'acc': 0.689834713935852, 'f1': 0.518622636795044, 'prec': 0.8723404407501221, 'rec': 0.36899998784065247}
Epoch 2/5


train: 100%|██████████| 40/40 [00:11<00:00,  3.40it/s]


train_loss: 0.3079887110739946


val: 100%|██████████| 18/18 [00:06<00:00,  3.00it/s]


{'acc': 0.7772243618965149, 'f1': 0.776780366897583, 'prec': 0.7958657145500183, 'rec': 0.7884624600410461}
{'acc': 0.8317862749099731, 'f1': 0.8169500231742859, 'prec': 0.8052452802658081, 'rec': 0.8289999961853027}
Epoch 3/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.31it/s]


train_loss: 0.22617303170263767


val: 100%|██████████| 18/18 [00:06<00:00,  2.89it/s]


{'acc': 0.7849218845367432, 'f1': 0.7639349699020386, 'prec': 0.8492031097412109, 'rec': 0.7635782957077026}
{'acc': 0.7147384881973267, 'f1': 0.5424836874008179, 'prec': 0.9907161593437195, 'rec': 0.3734999895095825}
Epoch 4/5


train: 100%|██████████| 40/40 [00:11<00:00,  3.33it/s]


train_loss: 0.1304917087778449


val: 100%|██████████| 18/18 [00:05<00:00,  3.04it/s]


{'acc': 0.9456644654273987, 'f1': 0.945421040058136, 'prec': 0.9444975256919861, 'rec': 0.9478068947792053}
{'acc': 0.948381245136261, 'f1': 0.9423076510429382, 'prec': 0.9538934230804443, 'rec': 0.9309999942779541}
Epoch 5/5


train: 100%|██████████| 40/40 [00:11<00:00,  3.34it/s]


train_loss: 0.10917811095714569


val: 100%|██████████| 18/18 [00:06<00:00,  2.93it/s]


{'acc': 0.7987321615219116, 'f1': 0.7984120845794678, 'prec': 0.8165451288223267, 'rec': 0.8097108006477356}
{'acc': 0.8263527154922485, 'f1': 0.8235564231872559, 'prec': 0.7626757621765137, 'rec': 0.8949999809265137}
Epoch 1/5


train: 100%|██████████| 40/40 [00:32<00:00,  1.25it/s]


train_loss: 0.2643843815661967


val: 100%|██████████| 18/18 [00:13<00:00,  1.32it/s]


{'acc': 0.48041656613349915, 'f1': 0.3657628893852234, 'prec': 0.7328288555145264, 'rec': 0.5252379179000854}
{'acc': 0.503961980342865, 'f1': 0.6460991501808167, 'prec': 0.4772130846977234, 'rec': 1.0}
Epoch 2/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.25it/s]


train_loss: 0.08145746529335156


val: 100%|██████████| 18/18 [00:06<00:00,  2.97it/s]


{'acc': 0.9452117085456848, 'f1': 0.9451195001602173, 'prec': 0.9458686113357544, 'rec': 0.9498085379600525}
{'acc': 0.9621915221214294, 'f1': 0.9597104787826538, 'prec': 0.9272727370262146, 'rec': 0.9944999814033508}
Epoch 3/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.12it/s]


train_loss: 0.03469431370031088


val: 100%|██████████| 18/18 [00:05<00:00,  3.03it/s]


{'acc': 0.9848313331604004, 'f1': 0.9847276210784912, 'prec': 0.9839066863059998, 'rec': 0.9858379364013672}
{'acc': 0.9880009293556213, 'f1': 0.9867267608642578, 'prec': 0.988459587097168, 'rec': 0.9850000143051147}
Epoch 4/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.31it/s]


train_loss: 0.04690607851371169


val: 100%|██████████| 18/18 [00:06<00:00,  2.99it/s]


{'acc': 0.8225039839744568, 'f1': 0.8212701082229614, 'prec': 0.8591954112052917, 'rec': 0.8378154635429382}
{'acc': 0.8562372922897339, 'f1': 0.862998902797699, 'prec': 0.759013295173645, 'rec': 1.0}
Epoch 5/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.25it/s]


train_loss: 0.03073437981074676


val: 100%|██████████| 18/18 [00:06<00:00,  3.00it/s]


{'acc': 0.843106210231781, 'f1': 0.8424911499023438, 'prec': 0.871333122253418, 'rec': 0.8566404581069946}
{'acc': 0.8648403882980347, 'f1': 0.8701326847076416, 'prec': 0.7701193690299988, 'rec': 1.0}
Epoch 1/5


train: 100%|██████████| 40/40 [00:21<00:00,  1.89it/s]


train_loss: 0.27514699213206767


val: 100%|██████████| 18/18 [00:10<00:00,  1.77it/s]


{'acc': 0.5204890370368958, 'f1': 0.4369925856590271, 'prec': 0.7428363561630249, 'rec': 0.5618535280227661}
{'acc': 0.5843332409858704, 'f1': 0.6854009628295898, 'prec': 0.5213764309883118, 'rec': 1.0}
Epoch 2/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.12it/s]


train_loss: 0.0637218679767102


val: 100%|██████████| 18/18 [00:06<00:00,  2.87it/s]


{'acc': 0.5825220942497253, 'f1': 0.5338701605796814, 'prec': 0.7601456642150879, 'rec': 0.6185353994369507}
{'acc': 0.6112746000289917, 'f1': 0.6996676921844482, 'prec': 0.5380683541297913, 'rec': 1.0}
Epoch 3/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.30it/s]


train_loss: 0.04632248259149492


val: 100%|██████████| 18/18 [00:05<00:00,  3.05it/s]


{'acc': 0.9261942505836487, 'f1': 0.9253885746002197, 'prec': 0.9263249039649963, 'rec': 0.9246247410774231}
{'acc': 0.9128367900848389, 'f1': 0.9006963968276978, 'prec': 0.9302077889442444, 'rec': 0.8730000257492065}
Epoch 4/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.31it/s]


train_loss: 0.026226441957987845


val: 100%|██████████| 18/18 [00:05<00:00,  3.02it/s]


{'acc': 0.9850577116012573, 'f1': 0.9848917722702026, 'prec': 0.9860728979110718, 'rec': 0.983931303024292}
{'acc': 0.9782657623291016, 'f1': 0.9755102396011353, 'prec': 0.9958333373069763, 'rec': 0.9559999704360962}
Epoch 5/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.16it/s]


train_loss: 0.020533011126099156


val: 100%|██████████| 18/18 [00:06<00:00,  2.94it/s]


{'acc': 0.9178175330162048, 'f1': 0.9151722192764282, 'prec': 0.9347122311592102, 'rec': 0.909250020980835}
{'acc': 0.8929137587547302, 'f1': 0.865891695022583, 'prec': 1.0, 'rec': 0.7634999752044678}


In [4]:
N_sample, run

(40, 5)