In [None]:
import ast
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,"
import numpy as np
from PIL import Image

import torch
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import torchmetrics
from tqdm import tqdm
import matplotlib.pyplot as plt



In [None]:
class DiscriminatorDataset(Dataset):
    """Map-style dataset that pairs image file paths with labels.

    Images are opened from disk on every access, so only the path and label
    sequences are held in memory.

    Args:
        image_list: Sequence of image file paths.
        label_list: Sequence of labels, index-aligned with ``image_list``.
            Labels are returned unchanged.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, image_list, label_list, transform=None):
        self.image_list = image_list
        self.label_list = label_list
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(image, label)``.

        The image is always converted to 3-channel RGB: PNG files may be
        stored as RGBA, palette or grayscale, which would not match a
        3-channel normalisation / 3-channel network input downstream.
        """
        # The context manager releases the file handle as soon as the pixel
        # data has been decoded by ``convert`` (which returns a new image).
        with Image.open(self.image_list[idx]) as raw_image:
            image = raw_image.convert("RGB")
        label = self.label_list[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# Train a ResNet-18 real-vs-synthetic discriminator for every synthetic dataset
# variant and run index, log per-epoch validation metrics and save the best
# checkpoint. Labels: 1 = real image, 0 = synthesized image.
for dname in ['mosaic_2_112', 'bezier224_5_0.2_0.05_1d1']:
    for run in range(0, 1):
        # Seed per run so the shuffled 80/20 split and the weight
        # initialisation are repeatable across kernel restarts.
        np.random.seed(run)
        torch.manual_seed(run)

        train_image_dir = '../data/WSSS4LUAD/1.training'
        no_stain_norm_synthesize_image_dir = f"../data/WSSS4LUAD/{dname}_run{run}/img"
        log_dir = f'discriminate_logs/wsss4luad_disc_{dname}_run{run}.txt'

        # Output folders are needed by the log, the figure and the checkpoint.
        os.makedirs('discriminate_logs', exist_ok=True)
        os.makedirs('weights', exist_ok=True)

        all_train_files = sorted(
            os.path.join(train_image_dir, name)
            for name in os.listdir(train_image_dir)
            if name.endswith(".png")
        )

        # Keep only real images whose filename-encoded label list has more
        # than one positive entry.
        real_image_list = []
        for image_name in all_train_files:
            file_name = os.path.basename(image_name)
            label_str = '[' + file_name.split(']')[0].split('[')[-1] + ']'
            # literal_eval only accepts Python literals, unlike eval(), which
            # would execute arbitrary expressions embedded in a filename.
            label = ast.literal_eval(label_str)
            if sum(label) > 1:
                real_image_list.append(image_name)

        synthesize_image_list = sorted(
            os.path.join(no_stain_norm_synthesize_image_dir, name)
            for name in os.listdir(no_stain_norm_synthesize_image_dir)
            if name.endswith(".png")
        )

        np.random.shuffle(real_image_list)
        np.random.shuffle(synthesize_image_list)

        # 80/20 train/validation split, done separately for each source.
        real_split = int(len(real_image_list) * 0.8)
        train_real_image_list = real_image_list[:real_split]
        val_real_image_list = real_image_list[real_split:]

        synthesize_split = int(len(synthesize_image_list) * 0.8)
        train_synthesize_image_list = synthesize_image_list[:synthesize_split]
        val_synthesize_image_list = synthesize_image_list[synthesize_split:]

        train_image_list = train_real_image_list + train_synthesize_image_list
        train_label_list = [1] * len(train_real_image_list) + [0] * len(train_synthesize_image_list)

        val_image_list = val_real_image_list + val_synthesize_image_list
        val_label_list = [1] * len(val_real_image_list) + [0] * len(val_synthesize_image_list)

        train_dataset = DiscriminatorDataset(train_image_list, train_label_list, transform=transforms.Compose([
            transforms.Resize((224,224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
        ]))

        val_dataset = DiscriminatorDataset(val_image_list, val_label_list, transform=transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
        ]))

        train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
        val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)

        model = timm.create_model('resnet18', pretrained=False, num_classes=2)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        # Two-class metrics over both classes. `task` is given explicitly so
        # these constructors use the same torchmetrics API as the binary
        # collection below.
        metrics = torchmetrics.MetricCollection({
            'acc': torchmetrics.Accuracy(task='multiclass', num_classes=2, average='micro'),
            'prec': torchmetrics.Precision(task='multiclass', num_classes=2, average='macro'),
            'rec': torchmetrics.Recall(task='multiclass', num_classes=2, average='macro'),
            'f1': torchmetrics.F1Score(task='multiclass', num_classes=2, average='macro')
        }).cuda()

        # Binary metrics with the synthesized (label 0) class as the positive.
        neg_metrics = torchmetrics.MetricCollection({
            'acc': torchmetrics.Accuracy(task='binary'),
            'prec': torchmetrics.Precision(task='binary'),
            'rec': torchmetrics.Recall(task='binary'),
            'f1': torchmetrics.F1Score(task='binary')
        }).cuda()

        model = model.cuda()

        EPOCH_NUM = 5
        best_acc = 0
        train_loss_list = []
        val_metrics_list = []
        for epoch in range(EPOCH_NUM):
            print(f"Epoch {epoch+1}/{EPOCH_NUM}")
            model.train()
            loss_list = []
            for batch in tqdm(train_dataloader, desc='train'):
                image, label = batch
                image = image.cuda()
                label = label.cuda()
                optimizer.zero_grad()
                output = model(image) # 1 real, 0 fake
                loss = criterion(output, label)
                loss_list.append(loss.item())
                loss.backward()
                optimizer.step()
            epoch_loss = np.mean(loss_list)
            print(f'train_loss: {epoch_loss}')
            train_loss_list.append(epoch_loss)

            model.eval()
            with torch.no_grad():
                for batch in tqdm(val_dataloader, desc='val'):
                    image, label = batch
                    image = image.cuda()
                    label = label.cuda()
                    output = model(image)
                    metrics.update(output, label)
                    # Use the model's actual decision (argmax) for the
                    # negative-class metrics. Passing the raw class-0 logit
                    # would be sigmoid-thresholded on its own, ignoring the
                    # class-1 logit and disagreeing with `metrics`.
                    fake_pred = (output.argmax(dim=1) == 0).long()
                    neg_metrics.update(fake_pred, 1 - label)

            metrics_dict = metrics.compute()
            neg_metrics_dict = neg_metrics.compute()
            metrics.reset()
            neg_metrics.reset()

            val_metrics_list.append(metrics_dict)
            # Checkpoint only when validation accuracy improves.
            if metrics_dict['acc'] > best_acc:
                best_acc = metrics_dict['acc']
                torch.save(model.state_dict(), f"weights/dis_{dname}_r18_e5_run{run}.pth")

            metrics_values = {k: v.item() for k, v in metrics_dict.items()}
            neg_metrics_values = {k: v.item() for k, v in neg_metrics_dict.items()}
            print(metrics_values)
            print(neg_metrics_values)

            # Append so that every epoch of this run ends up in one log file.
            with open(log_dir, 'a') as f:
                f.write(f"Epoch {epoch+1}/{EPOCH_NUM}\n")
                f.write(f'train_loss: {epoch_loss}\n')
                f.write(str(metrics_values) + '\n')
                f.write(str(neg_metrics_values) + '\n')

        # One figure per (dname, run): training loss and validation metrics
        # per epoch.
        fig, ax = plt.subplots()
        ax.plot(train_loss_list, label='train_loss')
        for metric_name in ('acc', 'prec', 'rec', 'f1'):
            ax.plot([m[metric_name].item() for m in val_metrics_list], label=metric_name)
        ax.set_xlabel('epoch')
        ax.set_title(f'{dname} run{run}')
        ax.legend()
        fig.savefig(f'discriminate_logs/dis_wsss4luad_{dname}_r18_e5_run{run}.png')
        plt.show()
        # Close explicitly so figures do not accumulate across loop iterations.
        plt.close(fig)