In [1]:
import os
# Restrict this process to physical GPU 4. This is set before `import torch`
# below so the restriction is already in place when CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,"
import numpy as np
from PIL import Image

import torch
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import torchmetrics
from tqdm import tqdm
import matplotlib.pyplot as plt


# torch.random.manual_seed(42)
# np.random.seed(42)

Extension horovod.torch has not been built: /home/fzj/.local/lib/python3.8/site-packages/horovod/torch/mpi_lib_v2.cpython-38-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.


In [2]:
class DiscriminatorDataset(Dataset):
    """Map-style dataset that pairs image file paths with labels.

    Args:
        image_list: Sequence of paths to image files readable by PIL.
        label_list: Sequence of labels, index-aligned with ``image_list``.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``
            (e.g. a ``torchvision.transforms.Compose``).

    Each item is an ``(image, label)`` tuple where ``image`` is the output of
    ``transform`` when one is given, otherwise the RGB ``PIL.Image``.
    """

    def __init__(self, image_list, label_list, transform=None):
        self.image_list = image_list
        self.label_list = label_list
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB and return it with its label."""
        # Open inside a context manager so the file handle is released right
        # away instead of lingering until garbage collection. ``convert``
        # forces the pixel data to load and yields a 3-channel image, so RGBA,
        # palette or greyscale PNGs do not break 3-channel normalisation later.
        with Image.open(self.image_list[idx]) as img:
            image = img.convert("RGB")
        label = self.label_list[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [3]:
# Train a ResNet-18 real-vs-synthetic discriminator for every
# (dname, run, N_sample) configuration and log validation metrics per epoch.
#
# For each configuration this cell:
#   1. collects real training images whose file-name label vector sums to > 1,
#   2. collects the synthesized images for that configuration,
#   3. shuffles and splits both sets 80/20 into train/val,
#   4. trains for EPOCH_NUM epochs, saving the weights of the best-accuracy epoch,
#   5. appends per-epoch results to a text log and saves a curve plot.
#
# NOTE(review): the shuffles below use NumPy's global RNG, and the seed calls in
# the import cell are commented out, so the train/val split differs between
# executions — confirm whether that is intended.

EPOCH_NUM = 5
TRAIN_FRACTION = 0.8        # share of each class that goes to the training split
MAX_TRAIN_PER_CLASS = 5000  # cap on training images taken from each class

for dname in ['bezier224_5_0.2_0.05_1d1']: # ['mosaic_2_112']:
    for run in [4]: # range(1, 6):
        for N_sample in [10]: # [10, 20, 30, 40]:
            train_image_dir = 'data/LUAD-HistoSeg/train'
            no_stain_norm_synthesize_image_dir = f"data/LUAD-HistoSeg/limit_N/{dname}_N_{N_sample}_run{run}/img"
            # Despite the name, `log_dir` is the path of the text log file.
            log_dir = f'discriminate_logs_limit_N/luad_{dname}_N_{N_sample}_run{run}.txt'
            weight_path = f"weights_limit_N/dis_luad_{dname}_N_{N_sample}_r18_e{EPOCH_NUM}_run{run}.pth"
            plot_path = f'discriminate_logs_limit_N/dis_luad_{dname}_N_{N_sample}_r18_e{EPOCH_NUM}_run{run}.png'

            # Create output folders up front so a fresh checkout does not fail
            # after the first epoch when the log / weights are written.
            for out_path in (log_dir, weight_path, plot_path):
                os.makedirs(os.path.dirname(out_path), exist_ok=True)

            all_train_image_list = sorted(
                os.path.join(train_image_dir, name)
                for name in os.listdir(train_image_dir)
                if name.endswith(".png")
            )

            # File names embed a space-separated label vector in square
            # brackets. Parse it numerically instead of eval()-ing text taken
            # from the file system; this also copes with a single-element
            # vector and with repeated spaces.
            real_image_list = []
            for image_path in all_train_image_list:
                label_str = os.path.basename(image_path).split(']')[0].split('[')[-1]
                label_total = sum(float(value) for value in label_str.split())
                if label_total > 1:
                    real_image_list.append(image_path)

            synthesize_image_list = sorted(
                os.path.join(no_stain_norm_synthesize_image_dir, name)
                for name in os.listdir(no_stain_norm_synthesize_image_dir)
                if name.endswith(".png")
            )

            np.random.shuffle(real_image_list)
            np.random.shuffle(synthesize_image_list)

            real_split = int(len(real_image_list) * TRAIN_FRACTION)
            train_real_image_list = real_image_list[:real_split]
            val_real_image_list = real_image_list[real_split:]

            synthesize_split = int(len(synthesize_image_list) * TRAIN_FRACTION)
            train_synthesize_image_list = synthesize_image_list[:synthesize_split]
            val_synthesize_image_list = synthesize_image_list[synthesize_split:]

            # Build labels from the actual slice lengths: a class with fewer
            # than MAX_TRAIN_PER_CLASS images would otherwise leave the label
            # list longer than, and misaligned with, the image list.
            capped_train_real = train_real_image_list[:MAX_TRAIN_PER_CLASS]
            capped_train_synthesize = train_synthesize_image_list[:MAX_TRAIN_PER_CLASS]
            train_image_list = capped_train_real + capped_train_synthesize
            train_label_list = [1] * len(capped_train_real) + [0] * len(capped_train_synthesize)

            val_image_list = val_real_image_list + val_synthesize_image_list
            val_label_list = [1] * len(val_real_image_list) + [0] * len(val_synthesize_image_list)

            train_dataset = DiscriminatorDataset(train_image_list, train_label_list, transform=transforms.Compose([
                transforms.Resize((224,224)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
            ]))

            val_dataset = DiscriminatorDataset(val_image_list, val_label_list, transform=transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
            ]))

            train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
            val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)

            model = timm.create_model('resnet18', pretrained=False, num_classes=2)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
            # criterion = nn.BCEWithLogitsLoss()
            criterion = nn.CrossEntropyLoss()
            metrics = torchmetrics.MetricCollection({ 
                'acc': torchmetrics.Accuracy(num_classes=2, average='micro'), 
                'prec': torchmetrics.Precision(num_classes=2, average='macro'), 
                'rec': torchmetrics.Recall(num_classes=2, average='macro'),
                'f1': torchmetrics.F1Score(num_classes=2, average='macro')
            }).cuda()

            model = model.cuda()

            best_acc = 0
            train_loss_list = []
            val_metrics_list = []
            for epoch in range(EPOCH_NUM):
                print(f"Epoch {epoch+1}/{EPOCH_NUM}")
                model.train()
                loss_list = []
                for batch in tqdm(train_dataloader, desc='train'):
                    image, label = batch
                    image = image.cuda()
                    label = label.cuda()
                    optimizer.zero_grad()
                    output = model(image) # 1 real, 0 fake
                    loss = criterion(output, label)
                    loss_list.append(loss.item())
                    loss.backward()
                    optimizer.step()
                epoch_train_loss = np.mean(loss_list)
                print(f'train_loss: {epoch_train_loss}')
                train_loss_list.append(epoch_train_loss)

                model.eval()
                # No gradients are needed anywhere in validation, so the whole
                # loop (forward pass and metric accumulation) runs under no_grad.
                with torch.no_grad():
                    for batch in tqdm(val_dataloader, desc='val'):
                        image, label = batch
                        image = image.cuda()
                        label = label.cuda()
                        output = model(image)
                        # Accumulate state only; the epoch value comes from compute().
                        metrics.update(output, label)

                metrics_dict = metrics.compute()
                metrics.reset()
                val_metrics_list.append(metrics_dict)
                epoch_metrics = {k: v.item() for k, v in metrics_dict.items()}
                # Keep only the checkpoint of the best validation accuracy so far.
                if metrics_dict['acc'] > best_acc:
                    best_acc = metrics_dict['acc']
                    torch.save(model.state_dict(), weight_path)
                print(epoch_metrics)

                with open(log_dir, 'a') as f:
                    f.write(f"Epoch {epoch+1}/{EPOCH_NUM}\n")
                    f.write(f'train_loss: {epoch_train_loss}\n')
                    f.write(str(epoch_metrics) + '\n')

            # One figure per configuration: training loss and validation
            # metrics over epochs (x axis is the 0-based epoch index).
            fig, ax = plt.subplots()
            ax.plot(train_loss_list, label='train_loss')
            for metric_name in ('acc', 'prec', 'rec', 'f1'):
                ax.plot([m[metric_name].item() for m in val_metrics_list], label=metric_name)
            ax.set_xlabel('epoch index')
            ax.legend()
            fig.savefig(plot_path)
            plt.close(fig)


Epoch 1/5


train: 100%|██████████| 40/40 [00:16<00:00,  2.39it/s]


train_loss: 0.41164139546453954


val: 100%|██████████| 18/18 [00:06<00:00,  2.74it/s]


{'acc': 0.6060674786567688, 'f1': 0.4891645908355713, 'prec': 0.7549657821655273, 'rec': 0.5659489035606384}
Epoch 2/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.92it/s]


train_loss: 0.11806804279331118


val: 100%|██████████| 18/18 [00:07<00:00,  2.47it/s]


{'acc': 0.9728322625160217, 'f1': 0.9726041555404663, 'prec': 0.9723385572433472, 'rec': 0.9728898406028748}
Epoch 3/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.89it/s]


train_loss: 0.03575878001283854


val: 100%|██████████| 18/18 [00:06<00:00,  2.63it/s]


{'acc': 0.9782657623291016, 'f1': 0.9781147241592407, 'prec': 0.9773366451263428, 'rec': 0.9791486263275146}
Epoch 4/5


train: 100%|██████████| 40/40 [00:14<00:00,  2.83it/s]


train_loss: 0.02312849119771272


val: 100%|██████████| 18/18 [00:07<00:00,  2.43it/s]


{'acc': 0.9916232824325562, 'f1': 0.9915533065795898, 'prec': 0.9912586212158203, 'rec': 0.9918714165687561}
Epoch 5/5


train: 100%|██████████| 40/40 [00:13<00:00,  2.90it/s]


train_loss: 0.01764708707924001


val: 100%|██████████| 18/18 [00:06<00:00,  2.64it/s]


{'acc': 0.9789449572563171, 'f1': 0.9788310527801514, 'prec': 0.9777830839157104, 'rec': 0.9807612895965576}


In [4]:
# Display the loop variables left in the namespace by the training cell,
# i.e. the last (N_sample, run) configuration that was executed.
N_sample, run

(10, 4)