In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,"
import numpy as np
from PIL import Image

import torch
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import torchmetrics
from tqdm import tqdm
import matplotlib.pyplot as plt

# torch.random.manual_seed(42)
# np.random.seed(42)

Extension horovod.torch has not been built: /home/fzj/.local/lib/python3.8/site-packages/horovod/torch/mpi_lib_v2.cpython-38-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.


In [2]:
class DiscriminatorDataset(Dataset):
    """Map-style dataset that loads image files and pairs them with labels.

    Args:
        image_list: Sequence of image file paths; each is opened with
            ``PIL.Image.open`` on access.
        label_list: Sequence of labels indexed in parallel with ``image_list``.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, image_list, label_list, transform=None):
        self.image_list = image_list
        self.label_list = label_list
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is always converted to 3-channel RGB so that RGBA,
        grayscale or palette files yield a consistent channel count for
        downstream per-channel transforms.
        """
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle, and convert() returns a fully loaded copy.
        with Image.open(self.image_list[idx]) as raw_image:
            image = raw_image.convert("RGB")
        label = self.label_list[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [3]:
# Train a ResNet-18 discriminator (label 1 = real, 0 = synthesized) for every
# synthesis method and every run, logging validation metrics after each epoch.
# NOTE(review): no RNG seed is set, so the shuffled splits differ per execution.

# Upper bound on the number of training images taken from each class.
max_train_per_class = 5000

# Log files, figures and checkpoints are written here; create the folders up
# front so the first write does not fail on a fresh checkout.
os.makedirs('discriminate_logs', exist_ok=True)
os.makedirs('weights', exist_ok=True)

for dname in ['mosaic_2_112', 'bezier224_5_0.2_0.05_1d1']:
    for run in range(1, 5):
        train_image_dir = 'data/BCSS-WSSS/training'
        no_stain_norm_synthesize_image_dir = f"data/BCSS-WSSS/{dname}_run{run}/img"
        # Despite the name, this is the path of the per-run log *file*.
        log_dir = f'discriminate_logs/bcss_disc_{dname}_run{run}.txt'

        # Match on the extension only, so names like "x.png.bak" are skipped.
        train_image_list = sorted(
            os.path.join(train_image_dir, fname)
            for fname in os.listdir(train_image_dir)
            if fname.endswith(".png")
        )
        real_image_list = []
        for image_name in train_image_list:
            # The characters between '[' and ']' are parsed as one digit each.
            # Parse the basename so brackets in a directory name cannot interfere.
            label_str = os.path.basename(image_name).split(']')[0].split('[')[-1]
            label = [int(i) for i in label_str]
            # Keep images whose digits sum to more than 1
            # (presumably: more than one class present — TODO confirm).
            if sum(label) > 1:
                real_image_list.append(image_name)

        synthesize_image_list = sorted(
            os.path.join(no_stain_norm_synthesize_image_dir, fname)
            for fname in os.listdir(no_stain_norm_synthesize_image_dir)
            if fname.endswith(".png")
        )

        np.random.shuffle(real_image_list)
        np.random.shuffle(synthesize_image_list)

        # 80/20 train/validation split, done separately for each class.
        real_split = int(len(real_image_list) * 0.8)
        train_real_image_list = real_image_list[:real_split]
        val_real_image_list = real_image_list[real_split:]

        synthesize_split = int(len(synthesize_image_list) * 0.8)
        train_synthesize_image_list = synthesize_image_list[:synthesize_split]
        val_synthesize_image_list = synthesize_image_list[synthesize_split:]

        # Cap each class, then size the labels from the *actual* slice lengths:
        # a class with fewer images than the cap would otherwise leave the
        # label list misaligned with the image list.
        train_real_subset = train_real_image_list[:max_train_per_class]
        train_synthesize_subset = train_synthesize_image_list[:max_train_per_class]
        train_image_list = train_real_subset + train_synthesize_subset
        train_label_list = [1] * len(train_real_subset) + [0] * len(train_synthesize_subset)

        val_image_list = val_real_image_list + val_synthesize_image_list
        val_label_list = [1] * len(val_real_image_list) + [0] * len(val_synthesize_image_list)

        train_dataset = DiscriminatorDataset(train_image_list, train_label_list, transform=transforms.Compose([
            transforms.Resize((224,224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
        ]))

        # Validation uses the same preprocessing without random flips.
        val_dataset = DiscriminatorDataset(val_image_list, val_label_list, transform=transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
        ]))

        train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
        val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)

        # A fresh, randomly initialised model and optimizer for every run.
        model = timm.create_model('resnet18', pretrained=False, num_classes=2)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        # criterion = nn.BCEWithLogitsLoss()
        criterion = nn.CrossEntropyLoss()
        metrics = torchmetrics.MetricCollection({ 
            'acc': torchmetrics.Accuracy(num_classes=2, average='micro'), 
            'prec': torchmetrics.Precision(num_classes=2, average='macro'), 
            'rec': torchmetrics.Recall(num_classes=2, average='macro'),
            'f1': torchmetrics.F1Score(num_classes=2, average='macro')
        }).cuda()

        model = model.cuda()

        EPOCH_NUM = 3 if dname == 'mosaic_2_112' else 5
        best_acc = 0
        train_loss_list = []
        val_metrics_list = []
        for epoch in range(EPOCH_NUM):
            print(f"Epoch {epoch+1}/{EPOCH_NUM}")
            model.train()
            loss_list = []
            for batch in tqdm(train_dataloader, desc='train'):
                image, label = batch
                image = image.cuda()
                label = label.cuda()
                optimizer.zero_grad()
                output = model(image) # 1 real, 0 fake
                loss = criterion(output, label)
                loss_list.append(loss.item())
                loss.backward()
                optimizer.step()
            # Mean batch loss for this epoch, computed once and reused below.
            epoch_train_loss = np.mean(loss_list)
            print(f'train_loss: {epoch_train_loss}')
            train_loss_list.append(epoch_train_loss)

            model.eval()
            for batch in tqdm(val_dataloader, desc='val'):
                image, label = batch
                image = image.cuda()
                label = label.cuda()
                with torch.no_grad():
                    output = model(image)
                # Accumulates metric state across validation batches.
                metrics(output, label)
            
            metrics_dict = metrics.compute()
            # Clear accumulated state so the next epoch starts from scratch.
            metrics.reset()
            val_metrics_list.append(metrics_dict)
            # Checkpoint only when validation accuracy improves within this run.
            if metrics_dict['acc'] > best_acc:
                best_acc = metrics_dict['acc']
                torch.save(model.state_dict(), f"weights/dis_bcss_{dname}_r18_e{EPOCH_NUM}_run{run}.pth")
            epoch_metrics = {k: v.item() for k, v in metrics_dict.items()}
            print(epoch_metrics)

            # Append mode: re-running the cell adds to an existing log file.
            with open(log_dir, 'a') as f:
                f.write(f"Epoch {epoch+1}/{EPOCH_NUM}\n")
                f.write(f'train_loss: {epoch_train_loss}\n')
                f.write(str(epoch_metrics) + '\n')

        # One figure per (method, run): training loss and validation metrics per epoch.
        plt.plot(train_loss_list, label='train_loss')
        plt.plot([i['acc'].item() for i in val_metrics_list], label='acc')
        plt.plot([i['prec'].item() for i in val_metrics_list], label='prec')
        plt.plot([i['rec'].item() for i in val_metrics_list], label='rec')
        plt.plot([i['f1'].item() for i in val_metrics_list], label='f1')
        plt.legend()
        plt.savefig(f'discriminate_logs/dis_bcss_{dname}_r18_e{EPOCH_NUM}_run{run}.png')
        plt.clf()
        plt.close()

Epoch 1/3


train: 100%|██████████| 40/40 [00:33<00:00,  1.18it/s]


train_loss: 0.45690586790442467


val: 100%|██████████| 19/19 [00:19<00:00,  1.01s/it]


{'acc': 0.43439385294914246, 'f1': 0.3283298909664154, 'prec': 0.7118644118309021, 'rec': 0.5158419609069824}
Epoch 2/3


train: 100%|██████████| 40/40 [00:13<00:00,  3.06it/s]


train_loss: 0.029635536746354774


val: 100%|██████████| 19/19 [00:06<00:00,  2.94it/s]


{'acc': 0.9735912084579468, 'f1': 0.9725513458251953, 'prec': 0.9783719778060913, 'rec': 0.968250036239624}
Epoch 3/3


train: 100%|██████████| 40/40 [00:12<00:00,  3.12it/s]


train_loss: 0.004228774692455773


val: 100%|██████████| 19/19 [00:06<00:00,  2.85it/s]


{'acc': 0.9265959858894348, 'f1': 0.9259300231933594, 'prec': 0.9249893426895142, 'rec': 0.9371662139892578}
Epoch 1/3


train: 100%|██████████| 40/40 [00:22<00:00,  1.76it/s]


train_loss: 0.46285262778401376


val: 100%|██████████| 19/19 [00:11<00:00,  1.62it/s]


{'acc': 0.5739238858222961, 'f1': 0.5436827540397644, 'prec': 0.7463169097900391, 'rec': 0.6352074146270752}
Epoch 2/3


train: 100%|██████████| 40/40 [00:16<00:00,  2.41it/s]


train_loss: 0.0316638804768445


val: 100%|██████████| 19/19 [00:06<00:00,  2.80it/s]


{'acc': 0.9995841383934021, 'f1': 0.9995719194412231, 'prec': 0.9996442794799805, 'rec': 0.999500036239624}
Epoch 3/3


train: 100%|██████████| 40/40 [00:17<00:00,  2.35it/s]


train_loss: 0.002293558388191741


val: 100%|██████████| 19/19 [00:06<00:00,  2.82it/s]


{'acc': 0.9958411455154419, 'f1': 0.9957261085510254, 'prec': 0.9950494766235352, 'rec': 0.996440052986145}
Epoch 1/3


train: 100%|██████████| 40/40 [00:19<00:00,  2.02it/s]


train_loss: 0.4599551848135889


val: 100%|██████████| 19/19 [00:10<00:00,  1.75it/s]


{'acc': 0.4217092990875244, 'f1': 0.3048113286495209, 'prec': 0.7091612815856934, 'rec': 0.5049839615821838}
Epoch 2/3


train: 100%|██████████| 40/40 [00:15<00:00,  2.58it/s]


train_loss: 0.03467919232789427


val: 100%|██████████| 19/19 [00:06<00:00,  2.79it/s]


{'acc': 0.9800374507904053, 'f1': 0.9793045520782471, 'prec': 0.9834767580032349, 'rec': 0.9760000109672546}
Epoch 3/3


train: 100%|██████████| 40/40 [00:14<00:00,  2.74it/s]


train_loss: 0.014059252038714476


val: 100%|██████████| 19/19 [00:06<00:00,  2.91it/s]


{'acc': 0.41588687896728516, 'f1': 0.29372888803482056, 'prec': 0.20794343948364258, 'rec': 0.5}
Epoch 1/3


train: 100%|██████████| 40/40 [00:14<00:00,  2.74it/s]


train_loss: 0.47491582706570623


val: 100%|██████████| 19/19 [00:07<00:00,  2.48it/s]


{'acc': 0.6288209557533264, 'f1': 0.6129434108734131, 'prec': 0.7637209892272949, 'rec': 0.682199239730835}
Epoch 2/3


train: 100%|██████████| 40/40 [00:13<00:00,  2.96it/s]


train_loss: 0.04768384873168543


val: 100%|██████████| 19/19 [00:06<00:00,  2.86it/s]


{'acc': 0.945726752281189, 'f1': 0.9445104002952576, 'prec': 0.9422019720077515, 'rec': 0.9476381540298462}
Epoch 3/3


train: 100%|██████████| 40/40 [00:13<00:00,  2.91it/s]


train_loss: 0.03888216817285865


val: 100%|██████████| 19/19 [00:06<00:00,  2.80it/s]


{'acc': 0.9977126121520996, 'f1': 0.9976444244384766, 'prec': 0.9979745149612427, 'rec': 0.9973219633102417}
Epoch 1/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.76it/s]


train_loss: 0.536664330214262


val: 100%|██████████| 19/19 [00:07<00:00,  2.56it/s]


{'acc': 0.621334969997406, 'f1': 0.4656100869178772, 'prec': 0.7638362646102905, 'rec': 0.5459740161895752}
Epoch 2/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.93it/s]


train_loss: 0.15183138456195594


val: 100%|██████████| 19/19 [00:06<00:00,  2.74it/s]


{'acc': 0.7001455426216125, 'f1': 0.695027232170105, 'prec': 0.788831353187561, 'rec': 0.742965042591095}
Epoch 3/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.91it/s]


train_loss: 0.028922945738304406


val: 100%|██████████| 19/19 [00:07<00:00,  2.70it/s]


{'acc': 0.996880829334259, 'f1': 0.9967864751815796, 'prec': 0.9973441958427429, 'rec': 0.9962500333786011}
Epoch 4/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.92it/s]


train_loss: 0.01689873365103267


val: 100%|██████████| 19/19 [00:06<00:00,  2.75it/s]


{'acc': 0.4187980890274048, 'f1': 0.29929330945014954, 'prec': 0.7085505723953247, 'rec': 0.5024920105934143}
Epoch 5/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.85it/s]


train_loss: 0.04337822169763968


val: 100%|██████████| 19/19 [00:07<00:00,  2.70it/s]


{'acc': 0.9908505082130432, 'f1': 0.9905532598495483, 'prec': 0.9922888278961182, 'rec': 0.9889999628067017}
Epoch 1/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.75it/s]


train_loss: 0.48895538598299026


val: 100%|██████████| 19/19 [00:06<00:00,  2.73it/s]


{'acc': 0.6625077724456787, 'f1': 0.6621431112289429, 'prec': 0.6729483604431152, 'rec': 0.6761147975921631}
Epoch 2/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.79it/s]


train_loss: 0.10969736934639514


val: 100%|██████████| 19/19 [00:06<00:00,  2.75it/s]


{'acc': 0.9369931221008301, 'f1': 0.9340372085571289, 'prec': 0.9444666504859924, 'rec': 0.9276340007781982}
Epoch 3/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.84it/s]


train_loss: 0.04465504405088723


val: 100%|██████████| 19/19 [00:06<00:00,  2.91it/s]


{'acc': 0.5874401926994324, 'f1': 0.5640907287597656, 'prec': 0.7335453033447266, 'rec': 0.6446173787117004}
Epoch 4/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.94it/s]


train_loss: 0.09881126694381237


val: 100%|██████████| 19/19 [00:06<00:00,  2.84it/s]


{'acc': 0.9609066247940063, 'f1': 0.9591508507728577, 'prec': 0.9686353206634521, 'rec': 0.953000009059906}
Epoch 5/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.77it/s]


train_loss: 0.021565171785186976


val: 100%|██████████| 19/19 [00:06<00:00,  2.73it/s]


{'acc': 0.7013932466506958, 'f1': 0.6963624954223633, 'prec': 0.7896836400032043, 'rec': 0.7441050410270691}
Epoch 1/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.74it/s]


train_loss: 0.4742809824645519


val: 100%|██████████| 19/19 [00:06<00:00,  2.73it/s]


{'acc': 0.5851528644561768, 'f1': 0.3714678883552551, 'prec': 0.7923605442047119, 'rec': 0.5012500286102295}
Epoch 2/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.74it/s]


train_loss: 0.09162341065239162


val: 100%|██████████| 19/19 [00:06<00:00,  2.80it/s]


{'acc': 0.7040964961051941, 'f1': 0.7011935114860535, 'prec': 0.7018338441848755, 'rec': 0.7075386047363281}
Epoch 3/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.72it/s]


train_loss: 0.1037307508289814


val: 100%|██████████| 19/19 [00:06<00:00,  2.81it/s]


{'acc': 0.9893949031829834, 'f1': 0.9890611171722412, 'prec': 0.9902423620223999, 'rec': 0.9879699945449829}
Epoch 4/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.81it/s]


train_loss: 0.016067964711692185


val: 100%|██████████| 19/19 [00:06<00:00,  2.81it/s]


{'acc': 0.9960490465164185, 'f1': 0.9959295988082886, 'prec': 0.9964863061904907, 'rec': 0.9953939914703369}
Epoch 5/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.84it/s]


train_loss: 0.00786226501732017


val: 100%|██████████| 19/19 [00:06<00:00,  2.78it/s]


{'acc': 0.9916822910308838, 'f1': 0.9914250373840332, 'prec': 0.9923279881477356, 'rec': 0.9905760288238525}
Epoch 1/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.99it/s]


train_loss: 0.5220331743359565


val: 100%|██████████| 19/19 [00:06<00:00,  2.72it/s]


{'acc': 0.5643584728240967, 'f1': 0.5537344217300415, 'prec': 0.6352036595344543, 'rec': 0.6075072884559631}
Epoch 2/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.86it/s]


train_loss: 0.15109115885570645


val: 100%|██████████| 19/19 [00:06<00:00,  2.77it/s]


{'acc': 0.8247036933898926, 'f1': 0.8013138771057129, 'prec': 0.8842260837554932, 'rec': 0.789322018623352}
Epoch 3/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.70it/s]


train_loss: 0.02457386398309609


val: 100%|██████████| 19/19 [00:06<00:00,  2.85it/s]


{'acc': 0.9677687883377075, 'f1': 0.9670587778091431, 'prec': 0.9645254015922546, 'rec': 0.9705380797386169}
Epoch 4/5


train: 100%|██████████| 40/40 [00:15<00:00,  2.64it/s]


train_loss: 0.01762622712412849


val: 100%|██████████| 19/19 [00:06<00:00,  3.00it/s]


{'acc': 0.6294447779655457, 'f1': 0.6143638491630554, 'prec': 0.759251594543457, 'rec': 0.6820132732391357}
Epoch 5/5


train: 100%|██████████| 40/40 [00:16<00:00,  2.50it/s]


train_loss: 0.011248814026475883


val: 100%|██████████| 19/19 [00:06<00:00,  3.04it/s]


{'acc': 0.9611145853996277, 'f1': 0.9597862958908081, 'prec': 0.9623170495033264, 'rec': 0.9576420187950134}
