In [1]:
import os
# The GPU selection is written to the environment before torch is imported below,
# so it is already in place when CUDA is first initialised.
# NOTE(review): "1," presumably exposes only GPU index 1 to this process — confirm
# the index (and the trailing comma) on the target machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,"
import numpy as np
from PIL import Image

import torch
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import torchmetrics
from tqdm import tqdm
import matplotlib.pyplot as plt


# Global seeding is left commented out in this cell.
# torch.random.manual_seed(42)
# np.random.seed(42)

Extension horovod.torch has not been built: /home/fzj/.local/lib/python3.8/site-packages/horovod/torch/mpi_lib_v2.cpython-38-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.


In [2]:
class DiscriminatorDataset(Dataset):
    """Map-style dataset that pairs image file paths with class labels.

    Args:
        image_list: Sequence of image file paths; its length is the dataset size.
        label_list: Sequence of labels indexed in parallel with ``image_list``.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, image_list, label_list, transform=None):
        self.image_list = image_list
        self.label_list = label_list
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``(image, label)`` pair.

        The image is always converted to 3-channel RGB so that RGBA, palette or
        grayscale files yield the same channel layout as ordinary RGB files.
        """
        # PIL opens files lazily; convert() reads the pixel data, after which the
        # context manager can release the underlying file handle immediately.
        with Image.open(self.image_list[idx]) as opened:
            image = opened.convert("RGB")
        label = self.label_list[idx]

        # Compare against None explicitly so a falsy-but-valid callable is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [3]:
# ---------------------------------------------------------------------------
# Real-vs-synthetic discriminator experiment.
#
# For every synthetic dataset variant and run index, a ResNet-18 is trained from
# scratch to tell real multi-label training patches (label 1) from synthesized
# patches (label 0). Validation metrics are logged per epoch, the best-accuracy
# checkpoint is saved, and a curve plot is written per run.
# ---------------------------------------------------------------------------
TRAIN_FRACTION = 0.8          # share of each (shuffled) image list used for training
MAX_TRAIN_PER_CLASS = 5000    # cap on training images taken from each class
BATCH_SIZE = 256
NUM_WORKERS = 4
LEARNING_RATE = 0.001
EPOCH_NUM = 5
LOG_ROOT = 'discriminate_logs'
WEIGHT_ROOT = 'weights'


def _list_pngs(directory):
    """Return the sorted paths of the files in ``directory`` whose name ends with ``.png``."""
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.png')
    )


def _parse_bracket_label(image_path):
    """Parse the bracketed label in a file name (e.g. ``...[1 0 0 1].png``) into a list of ints.

    Replaces the previous ``eval`` of file-name text: tokens may be separated by
    spaces or commas and each must be an integer literal, otherwise ``ValueError``
    is raised.
    """
    label_str = os.path.basename(image_path).split(']')[0].split('[')[-1]
    return [int(token) for token in label_str.replace(',', ' ').split()]


def _split_train_val(paths, train_fraction=TRAIN_FRACTION):
    """Split ``paths`` into a leading train part and a trailing validation part."""
    cut = int(len(paths) * train_fraction)
    return paths[:cut], paths[cut:]


# The transforms do not depend on the dataset variant or run, so build them once.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
])

# Make sure the output locations exist before the first log write / checkpoint save.
os.makedirs(LOG_ROOT, exist_ok=True)
os.makedirs(WEIGHT_ROOT, exist_ok=True)

for dname in ['mosaic_2_112', 'bezier224_5_0.2_0.05_1d1']:
    for run in range(1, 5):
        # Seed per run so the shuffle-based split and torch-side randomness are
        # repeatable while different runs still differ. GPU kernels may still be
        # non-deterministic, so results are not guaranteed to be bit-exact.
        np.random.seed(run)
        torch.manual_seed(run)

        train_image_dir = 'data/LUAD-HistoSeg/train'
        no_stain_norm_synthesize_image_dir = f"data/LUAD-HistoSeg/{dname}_run{run}/img"
        log_dir = f'discriminate_logs/luad_disc_{dname}_run{run}.txt'

        # Keep only real patches whose label has more than one positive entry.
        real_image_list = [
            image_name
            for image_name in _list_pngs(train_image_dir)
            if sum(_parse_bracket_label(image_name)) > 1
        ]
        synthesize_image_list = _list_pngs(no_stain_norm_synthesize_image_dir)

        if not real_image_list or not synthesize_image_list:
            raise ValueError(
                f"No usable images for {dname} run {run}: "
                f"{len(real_image_list)} real, {len(synthesize_image_list)} synthesized"
            )

        np.random.shuffle(real_image_list)
        np.random.shuffle(synthesize_image_list)

        train_real_image_list, val_real_image_list = _split_train_val(real_image_list)
        train_synthesize_image_list, val_synthesize_image_list = _split_train_val(synthesize_image_list)

        # Derive the label counts from the slices actually taken: a class with fewer
        # than MAX_TRAIN_PER_CLASS images would otherwise shift labels onto the
        # wrong images.
        train_real_subset = train_real_image_list[:MAX_TRAIN_PER_CLASS]
        train_synthesize_subset = train_synthesize_image_list[:MAX_TRAIN_PER_CLASS]
        train_image_list = train_real_subset + train_synthesize_subset
        train_label_list = [1] * len(train_real_subset) + [0] * len(train_synthesize_subset)

        val_image_list = val_real_image_list + val_synthesize_image_list
        val_label_list = [1] * len(val_real_image_list) + [0] * len(val_synthesize_image_list)

        train_dataset = DiscriminatorDataset(train_image_list, train_label_list, transform=train_transform)
        val_dataset = DiscriminatorDataset(val_image_list, val_label_list, transform=val_transform)

        train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
        val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

        model = timm.create_model('resnet18', pretrained=False, num_classes=2)
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        criterion = nn.CrossEntropyLoss()
        # NOTE(review): these constructor arguments follow the torchmetrics API this
        # notebook was run with; newer releases may expect a `task=` argument — confirm
        # against the installed version before upgrading.
        metrics = torchmetrics.MetricCollection({
            'acc': torchmetrics.Accuracy(num_classes=2, average='micro'),
            'prec': torchmetrics.Precision(num_classes=2, average='macro'),
            'rec': torchmetrics.Recall(num_classes=2, average='macro'),
            'f1': torchmetrics.F1Score(num_classes=2, average='macro')
        }).cuda()

        model = model.cuda()

        best_acc = 0
        train_loss_list = []
        val_metrics_list = []
        for epoch in range(EPOCH_NUM):
            print(f"Epoch {epoch+1}/{EPOCH_NUM}")

            # ---- training pass ----
            model.train()
            loss_list = []
            for image, label in tqdm(train_dataloader, desc='train'):
                image = image.cuda()
                label = label.cuda()
                optimizer.zero_grad()
                output = model(image)  # class index 1 = real, 0 = synthesized
                loss = criterion(output, label)
                loss.backward()
                optimizer.step()
                loss_list.append(loss.item())
            epoch_loss = float(np.mean(loss_list))
            print(f'train_loss: {epoch_loss}')
            train_loss_list.append(epoch_loss)

            # ---- validation pass ----
            model.eval()
            with torch.no_grad():
                for image, label in tqdm(val_dataloader, desc='val'):
                    output = model(image.cuda())
                    # update() only accumulates state; the per-batch value that
                    # forward() would also compute is never used here.
                    metrics.update(output, label.cuda())

            metrics_dict = metrics.compute()
            metrics.reset()
            val_metrics_list.append(metrics_dict)
            metrics_float = {k: v.item() for k, v in metrics_dict.items()}

            # Keep only the checkpoint with the best validation accuracy so far.
            if metrics_float['acc'] > best_acc:
                best_acc = metrics_float['acc']
                torch.save(model.state_dict(), f"weights/dis_luad_{dname}_r18_e{EPOCH_NUM}_run{run}.pth")
            print(metrics_float)

            # Append mode: logs from earlier executions of this cell are preserved.
            with open(log_dir, 'a') as f:
                f.write(f"Epoch {epoch+1}/{EPOCH_NUM}\n")
                f.write(f'train_loss: {epoch_loss}\n')
                f.write(str(metrics_float) + '\n')

        # One figure per (dataset variant, run): training loss and validation metrics.
        fig, ax = plt.subplots()
        ax.plot(train_loss_list, label='train_loss')
        for metric_name in ('acc', 'prec', 'rec', 'f1'):
            ax.plot([m[metric_name].item() for m in val_metrics_list], label=metric_name)
        ax.set_xlabel('epoch (0-based)')
        ax.set_title(f'{dname} run {run}')
        ax.legend()
        fig.savefig(f'discriminate_logs/dis_luad_{dname}_r18_e{EPOCH_NUM}_run{run}.png')
        # Close explicitly so figures do not accumulate across the eight runs.
        plt.close(fig)


Epoch 1/5


train: 100%|██████████| 40/40 [00:38<00:00,  1.05it/s]


train_loss: 0.2802371967583895


val: 100%|██████████| 18/18 [00:17<00:00,  1.01it/s]


{'acc': 0.45279601216316223, 'f1': 0.3116721212863922, 'prec': 0.22639800608158112, 'rec': 0.5}
Epoch 2/5


train: 100%|██████████| 40/40 [00:11<00:00,  3.49it/s]


train_loss: 0.11707659251987934


val: 100%|██████████| 18/18 [00:05<00:00,  3.24it/s]


{'acc': 0.7138329148292542, 'f1': 0.7029083967208862, 'prec': 0.8063725233078003, 'rec': 0.7385188341140747}
Epoch 3/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.33it/s]


train_loss: 0.05973495030775666


val: 100%|██████████| 18/18 [00:05<00:00,  3.02it/s]


{'acc': 0.7391895055770874, 'f1': 0.7015282511711121, 'prec': 0.8386102914810181, 'rec': 0.7120000123977661}
Epoch 4/5


train: 100%|██████████| 40/40 [00:11<00:00,  3.51it/s]


train_loss: 0.027146280149463565


val: 100%|██████████| 18/18 [00:05<00:00,  3.21it/s]


{'acc': 0.9051392078399658, 'f1': 0.9016323089599609, 'prec': 0.9259368181228638, 'rec': 0.895293116569519}
Epoch 5/5


train: 100%|██████████| 40/40 [00:11<00:00,  3.49it/s]


train_loss: 0.01846548995235935


val: 100%|██████████| 18/18 [00:06<00:00,  2.93it/s]


{'acc': 0.8716323375701904, 'f1': 0.8714957237243652, 'prec': 0.8890786170959473, 'rec': 0.8825764656066895}
Epoch 1/5


train: 100%|██████████| 40/40 [00:20<00:00,  1.96it/s]


train_loss: 0.26361997742205856


val: 100%|██████████| 18/18 [00:10<00:00,  1.73it/s]


{'acc': 0.459587961435318, 'f1': 0.32539594173431396, 'prec': 0.7279462218284607, 'rec': 0.5062060356140137}
Epoch 2/5


train: 100%|██████████| 40/40 [00:13<00:00,  3.01it/s]


train_loss: 0.10109653323888779


val: 100%|██████████| 18/18 [00:05<00:00,  3.02it/s]


{'acc': 0.8788770437240601, 'f1': 0.8788098692893982, 'prec': 0.893433690071106, 'rec': 0.8890236616134644}
Epoch 3/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.32it/s]


train_loss: 0.05778099694289267


val: 100%|██████████| 18/18 [00:06<00:00,  2.82it/s]


{'acc': 0.7575277090072632, 'f1': 0.7520928382873535, 'prec': 0.8256268501281738, 'rec': 0.7784443497657776}
Epoch 4/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.54it/s]


train_loss: 0.03289669820223935


val: 100%|██████████| 18/18 [00:06<00:00,  2.80it/s]


{'acc': 0.995245635509491, 'f1': 0.9952014684677124, 'prec': 0.9953560829162598, 'rec': 0.995051920413971}
Epoch 5/5


train: 100%|██████████| 40/40 [00:16<00:00,  2.40it/s]


train_loss: 0.024495779420249164


val: 100%|██████████| 18/18 [00:06<00:00,  2.84it/s]


{'acc': 0.9320806264877319, 'f1': 0.9302634596824646, 'prec': 0.9444923400878906, 'rec': 0.925086259841919}
Epoch 1/5


train: 100%|██████████| 40/40 [00:17<00:00,  2.23it/s]


train_loss: 0.2909013306722045


val: 100%|██████████| 18/18 [00:10<00:00,  1.80it/s]


{'acc': 0.45279601216316223, 'f1': 0.3116721212863922, 'prec': 0.22639800608158112, 'rec': 0.5}
Epoch 2/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.55it/s]


train_loss: 0.09893048452213407


val: 100%|██████████| 18/18 [00:06<00:00,  2.90it/s]


{'acc': 0.9669458866119385, 'f1': 0.9665753841400146, 'prec': 0.9677733182907104, 'rec': 0.9656134843826294}
Epoch 3/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.66it/s]


train_loss: 0.073856018204242


val: 100%|██████████| 18/18 [00:05<00:00,  3.02it/s]


{'acc': 0.6884763240814209, 'f1': 0.6730648279190063, 'prec': 0.7962085008621216, 'rec': 0.7153496146202087}
Epoch 4/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.80it/s]


train_loss: 0.038868117006495596


val: 100%|██████████| 18/18 [00:05<00:00,  3.10it/s]


{'acc': 0.8766130805015564, 'f1': 0.8705223798751831, 'prec': 0.9075295925140381, 'rec': 0.8638362884521484}
Epoch 5/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.83it/s]


train_loss: 0.020007005403749646


val: 100%|██████████| 18/18 [00:06<00:00,  2.97it/s]


{'acc': 0.9393253326416016, 'f1': 0.9392561912536621, 'prec': 0.9409171342849731, 'rec': 0.9445593357086182}
Epoch 1/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.73it/s]


train_loss: 0.28976317290216685


val: 100%|██████████| 18/18 [00:06<00:00,  2.69it/s]


{'acc': 0.45279601216316223, 'f1': 0.3116721212863922, 'prec': 0.22639800608158112, 'rec': 0.5}
Epoch 2/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.71it/s]


train_loss: 0.1267036149278283


val: 100%|██████████| 18/18 [00:06<00:00,  2.86it/s]


{'acc': 0.7475662231445312, 'f1': 0.7411009073257446, 'prec': 0.8210272789001465, 'rec': 0.7693421840667725}
Epoch 3/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.79it/s]


train_loss: 0.0936658188700676


val: 100%|██████████| 18/18 [00:06<00:00,  2.98it/s]


{'acc': 0.928911030292511, 'f1': 0.9283837080001831, 'prec': 0.92781662940979, 'rec': 0.929091215133667}
Epoch 4/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.79it/s]


train_loss: 0.07228850149549544


val: 100%|██████████| 18/18 [00:06<00:00,  2.73it/s]


{'acc': 0.5693910121917725, 'f1': 0.5145161747932434, 'prec': 0.7562788724899292, 'rec': 0.6065370440483093}
Epoch 5/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.75it/s]


train_loss: 0.03547874209471047


val: 100%|██████████| 18/18 [00:06<00:00,  2.85it/s]


{'acc': 0.9800769686698914, 'f1': 0.9798358678817749, 'prec': 0.981650710105896, 'rec': 0.9784744381904602}
Epoch 1/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.64it/s]


train_loss: 0.5485511988401413


val: 100%|██████████| 18/18 [00:06<00:00,  2.69it/s]


{'acc': 0.7711116075515747, 'f1': 0.7710365056991577, 'prec': 0.7747883796691895, 'rec': 0.7763209342956543}
Epoch 2/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.62it/s]


train_loss: 0.44764099791646006


val: 100%|██████████| 18/18 [00:06<00:00,  2.90it/s]


{'acc': 0.7790355682373047, 'f1': 0.7773261666297913, 'prec': 0.7770300507545471, 'rec': 0.7776954174041748}
Epoch 3/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.63it/s]


train_loss: 0.3693452216684818


val: 100%|██████████| 18/18 [00:06<00:00,  2.74it/s]


{'acc': 0.7756395936012268, 'f1': 0.7522942423820496, 'prec': 0.8439750671386719, 'rec': 0.7533714175224304}
Epoch 4/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.71it/s]


train_loss: 0.30874909050762656


val: 100%|██████████| 18/18 [00:05<00:00,  3.06it/s]


{'acc': 0.8034865260124207, 'f1': 0.8030539751052856, 'prec': 0.8238294124603271, 'rec': 0.8151764869689941}
Epoch 5/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.90it/s]


train_loss: 0.2935459241271019


val: 100%|██████████| 18/18 [00:06<00:00,  2.83it/s]


{'acc': 0.8222775459289551, 'f1': 0.8119735717773438, 'prec': 0.854107141494751, 'rec': 0.8073731064796448}
Epoch 1/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.70it/s]


train_loss: 0.5062922216951847


val: 100%|██████████| 18/18 [00:06<00:00,  2.75it/s]


{'acc': 0.6651573181152344, 'f1': 0.662660539150238, 'prec': 0.6624714732170105, 'rec': 0.6629440784454346}
Epoch 2/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.64it/s]


train_loss: 0.3938666068017483


val: 100%|██████████| 18/18 [00:06<00:00,  2.64it/s]


{'acc': 0.726963996887207, 'f1': 0.7244847416877747, 'prec': 0.7614856958389282, 'rec': 0.7425377368927002}
Epoch 3/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.67it/s]


train_loss: 0.34411893002688887


val: 100%|██████████| 18/18 [00:06<00:00,  2.71it/s]


{'acc': 0.8539732694625854, 'f1': 0.8505553007125854, 'prec': 0.8597240447998047, 'rec': 0.8472470045089722}
Epoch 4/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.67it/s]


train_loss: 0.3000147119164467


val: 100%|██████████| 18/18 [00:06<00:00,  2.96it/s]


{'acc': 0.8283903002738953, 'f1': 0.8201584815979004, 'prec': 0.8517694473266602, 'rec': 0.815460205078125}
Epoch 5/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.83it/s]


train_loss: 0.2662860106676817


val: 100%|██████████| 18/18 [00:05<00:00,  3.01it/s]


{'acc': 0.8777450919151306, 'f1': 0.8738411664962769, 'prec': 0.8915382623672485, 'rec': 0.86888188123703}
Epoch 1/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.90it/s]


train_loss: 0.5165678188204765


val: 100%|██████████| 18/18 [00:05<00:00,  3.01it/s]


{'acc': 0.6465927362442017, 'f1': 0.5750828981399536, 'prec': 0.7375434637069702, 'rec': 0.6133299469947815}
Epoch 2/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.91it/s]


train_loss: 0.3906002588570118


val: 100%|██████████| 18/18 [00:07<00:00,  2.54it/s]


{'acc': 0.7013810276985168, 'f1': 0.6961784362792969, 'prec': 0.749882698059082, 'rec': 0.7201536893844604}
Epoch 3/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.53it/s]


train_loss: 0.3722902424633503


val: 100%|██████████| 18/18 [00:07<00:00,  2.38it/s]


{'acc': 0.7480190396308899, 'f1': 0.7471429109573364, 'prec': 0.7701915502548218, 'rec': 0.7603530883789062}
Epoch 4/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.56it/s]


train_loss: 0.32267925702035427


val: 100%|██████████| 18/18 [00:06<00:00,  2.65it/s]


{'acc': 0.6918722987174988, 'f1': 0.6784014105796814, 'prec': 0.7889788746833801, 'rec': 0.7175899744033813}
Epoch 5/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.52it/s]


train_loss: 0.28207519426941874


val: 100%|██████████| 18/18 [00:06<00:00,  2.67it/s]


{'acc': 0.7557165622711182, 'f1': 0.7552247047424316, 'prec': 0.7736767530441284, 'rec': 0.7667827606201172}
Epoch 1/5


train: 100%|██████████| 40/40 [00:16<00:00,  2.49it/s]


train_loss: 0.5190721236169338


val: 100%|██████████| 18/18 [00:06<00:00,  2.68it/s]


{'acc': 0.7394158840179443, 'f1': 0.7198967933654785, 'prec': 0.7696326971054077, 'rec': 0.7209194898605347}
Epoch 2/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.17it/s]


train_loss: 0.41114903166890143


val: 100%|██████████| 18/18 [00:06<00:00,  2.87it/s]


{'acc': 0.7394158840179443, 'f1': 0.7322622537612915, 'prec': 0.7414960861206055, 'rec': 0.7304948568344116}
Epoch 3/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.21it/s]


train_loss: 0.37119037061929705


val: 100%|██████████| 18/18 [00:06<00:00,  2.88it/s]


{'acc': 0.8415213823318481, 'f1': 0.838133692741394, 'prec': 0.845445990562439, 'rec': 0.8353085517883301}
Epoch 4/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.13it/s]


train_loss: 0.2855131603777409


val: 100%|██████████| 18/18 [00:06<00:00,  2.90it/s]


{'acc': 0.8494453430175781, 'f1': 0.8441792726516724, 'prec': 0.8637049198150635, 'rec': 0.8395727872848511}
Epoch 5/5


train: 100%|██████████| 40/40 [00:12<00:00,  3.19it/s]


train_loss: 0.25537086501717565


val: 100%|██████████| 18/18 [00:05<00:00,  3.08it/s]


{'acc': 0.5725605487823486, 'f1': 0.5195649266242981, 'prec': 0.755332887172699, 'rec': 0.6093469262123108}
