In [1]:
import os
from copy import deepcopy

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.set()

## tqdm for loading bars
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import STL10
from torchvision import transforms

from torch.utils.data import Dataset
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor





# Use the first CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Folder (relative to the notebook) used as the Lightning root dir and searched
# for an already trained SimCLR checkpoint.
CHECKPOINT_PATH = "../saved_models"


  set_matplotlib_formats('svg', 'pdf') # For export


<Figure size 640x480 with 0 Axes>

## 定義data set
## Labeled

In [2]:
class MyData(Dataset):
    """Dataset over one label folder; the folder name doubles as the label.

    Every directory entry of ``<contrast_root_dir>/<label_dir>`` is treated as
    one sample, ordered by sorted entry name.
    """

    def __init__(self, contrast_root_dir, label_dir,transform):
        """Index the entries of ``contrast_root_dir/label_dir``.

        Args:
            contrast_root_dir: Root directory that contains the label folders.
            label_dir: Name of the label folder to read; also returned as the
                label of every sample.
            transform: Callable applied to each opened image. A falsy value
                (e.g. ``None``) returns the opened image unchanged.
        """
        self.contrast_root_dir = contrast_root_dir
        self.label_dir = label_dir
        self.transform = transform
        # Full path of the folder holding the images for this label.
        self.image_dir_path = os.path.join(contrast_root_dir, label_dir)
        # Sorted so that an index always maps to the same entry.
        self.image_list = sorted(os.listdir(self.image_dir_path))

    def __len__(self):
        """Return the number of entries found in the label folder."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Open the ``idx``-th entry and return ``(image, label_dir)``."""
        entry_name = self.image_list[idx]
        entry_path = os.path.join(self.contrast_root_dir, self.label_dir, entry_name)
        image = Image.open(entry_path)
        if not self.transform:
            return image, self.label_dir
        return self.transform(image), self.label_dir


## unlabeled

In [3]:
class unlabel_MyData(Dataset):
    """Dataset over one folder whose samples all carry the placeholder label -1.

    Every directory entry of ``<contrast_root_dir>/<label_dir>`` is treated as
    one sample, ordered by sorted entry name. The folder name is only used to
    locate the files, never returned.
    """

    def __init__(self, contrast_root_dir, label_dir,transform):
        """Index the entries of ``contrast_root_dir/label_dir``.

        Args:
            contrast_root_dir: Root directory that contains the sub-folders.
            label_dir: Name of the sub-folder to read.
            transform: Callable applied to each opened image. A falsy value
                (e.g. ``None``) returns the opened image unchanged.
        """
        self.contrast_root_dir = contrast_root_dir
        self.label_dir = label_dir
        self.transform = transform
        # Full path of the folder holding the images.
        self.image_dir_path = os.path.join(contrast_root_dir, label_dir)
        # Sorted so that an index always maps to the same entry.
        self.image_list = sorted(os.listdir(self.image_dir_path))

    def __len__(self):
        """Return the number of entries found in the folder."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Open the ``idx``-th entry and return ``(image, -1)``."""
        entry_name = self.image_list[idx]
        entry_path = os.path.join(self.contrast_root_dir, self.label_dir, entry_name)
        image = Image.open(entry_path)
        if not self.transform:
            return image, -1
        return self.transform(image), -1


# 定義augmentation過程
## n_views代表要經過幾次transform過程

In [4]:
class ContrastiveTransformations(object):
    """Callable that applies the same transform several times to one input.

    Each call returns a list with ``n_views`` results, produced by invoking
    ``base_transforms`` on the input once per view, in order.
    """

    def __init__(self, base_transforms, n_views=2):
        """Store the transform and the number of views to generate.

        Args:
            base_transforms: Callable invoked once per view.
            n_views: How many times ``base_transforms`` is applied per input.
        """
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, x):
        """Return ``[base_transforms(x), ...]`` with ``n_views`` elements."""
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transforms(x))
        return views

In [5]:
# Colour distortion used inside the pipeline below; it is wrapped in
# RandomApply so that it is only applied with probability 0.8.
_color_jitter = transforms.ColorJitter(brightness=0.5,
                                       contrast=0.5,
                                       saturation=0.5,
                                       hue=0.1)

# Random augmentation pipeline applied to every image: flip, colour jitter,
# optional grayscale, blur, then conversion to a tensor normalised with
# mean 0.5 and std 0.5.
contrast_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([_color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=9),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [6]:
# Dataset roots. NOTE(review): train_simclr() trains on `unlabeled_all_dataset`
# (built from the ".../valid" folder below) and validates on
# `contrast_all_dataset` (built from the ".../train" folder) - confirm this
# assignment is intended and not swapped.
contrast_root_dir = "E:/run_dataset_totalseg/train"
label_contrast_root_dir = "E:/run_dataset_totalseg/valid"

# One sub-folder per class inside each root.
L0_label_dir = 'L0'
L1_label_dir = 'L1'
L1_L2_label_dir = 'L1_L2'
L2_label_dir = 'L2'
L2_L3_label_dir = 'L2_L3'
L3_label_dir = 'L3'
L3_L4_label_dir = 'L3_L4'
L4_label_dir = 'L4'
L4_L5_label_dir = 'L4_L5'
L5_label_dir = 'L5'
T12_label_dir = 'T12'
T12_L1_label_dir = 'T12_L1'
unlabel = -1

# SimCLR.info_nce_loss() does `torch.cat(imgs, dim=0)` and treats the row half
# a batch away as the positive, so every sample must be a *list of two
# augmented views* of the same image. Passing `contrast_transforms` directly
# yields a single tensor per sample, which torch.cat() rejects; wrap the
# pipeline in ContrastiveTransformations to produce the two views.
contrast_view_transforms = ContrastiveTransformations(contrast_transforms, n_views=2)

# Datasets read from `contrast_root_dir` (label = folder name).
contrast_L0_dataset = MyData(contrast_root_dir, L0_label_dir, transform=contrast_view_transforms)
contrast_L1_dataset = MyData(contrast_root_dir, L1_label_dir, transform=contrast_view_transforms)
contrast_L1_L2_dataset = MyData(contrast_root_dir, L1_L2_label_dir, transform=contrast_view_transforms)
contrast_L2_dataset = MyData(contrast_root_dir, L2_label_dir, transform=contrast_view_transforms)
contrast_L2_L3_dataset = MyData(contrast_root_dir, L2_L3_label_dir, transform=contrast_view_transforms)
contrast_L3_dataset = MyData(contrast_root_dir, L3_label_dir, transform=contrast_view_transforms)
contrast_L3_L4_dataset = MyData(contrast_root_dir, L3_L4_label_dir, transform=contrast_view_transforms)
contrast_L4_dataset = MyData(contrast_root_dir, L4_label_dir, transform=contrast_view_transforms)
contrast_L4_L5_dataset = MyData(contrast_root_dir, L4_L5_label_dir, transform=contrast_view_transforms)
contrast_L5_dataset = MyData(contrast_root_dir, L5_label_dir, transform=contrast_view_transforms)
contrast_T12_dataset = MyData(contrast_root_dir, T12_label_dir, transform=contrast_view_transforms)
contrast_T12_L1_dataset = MyData(contrast_root_dir, T12_L1_label_dir, transform=contrast_view_transforms)

# Everything except L3, then L3 followed by the rest (index order is preserved
# from the original notebook).
contrast_other_dataset = (
    contrast_L0_dataset + contrast_L1_dataset + contrast_L1_L2_dataset
    + contrast_L2_dataset + contrast_L2_L3_dataset + contrast_L3_L4_dataset
    + contrast_L4_dataset + contrast_L4_L5_dataset + contrast_L5_dataset
    + contrast_T12_dataset + contrast_T12_L1_dataset
)
contrast_all_dataset = contrast_L3_dataset + contrast_other_dataset

# Datasets read from `label_contrast_root_dir`; every sample gets label -1.
unlabeled_L0_dataset = unlabel_MyData(label_contrast_root_dir, L0_label_dir, transform=contrast_view_transforms)
unlabeled_L1_dataset = unlabel_MyData(label_contrast_root_dir, L1_label_dir, transform=contrast_view_transforms)
unlabeled_L1_L2_dataset = unlabel_MyData(label_contrast_root_dir, L1_L2_label_dir, transform=contrast_view_transforms)
unlabeled_L2_dataset = unlabel_MyData(label_contrast_root_dir, L2_label_dir, transform=contrast_view_transforms)
unlabeled_L2_L3_dataset = unlabel_MyData(label_contrast_root_dir, L2_L3_label_dir, transform=contrast_view_transforms)
unlabeled_L3_dataset = unlabel_MyData(label_contrast_root_dir, L3_label_dir, transform=contrast_view_transforms)
unlabeled_L3_L4_dataset = unlabel_MyData(label_contrast_root_dir, L3_L4_label_dir, transform=contrast_view_transforms)
unlabeled_L4_dataset = unlabel_MyData(label_contrast_root_dir, L4_label_dir, transform=contrast_view_transforms)
unlabeled_L4_L5_dataset = unlabel_MyData(label_contrast_root_dir, L4_L5_label_dir, transform=contrast_view_transforms)
unlabeled_L5_dataset = unlabel_MyData(label_contrast_root_dir, L5_label_dir, transform=contrast_view_transforms)
unlabeled_T12_dataset = unlabel_MyData(label_contrast_root_dir, T12_label_dir, transform=contrast_view_transforms)
unlabeled_T12_L1_dataset = unlabel_MyData(label_contrast_root_dir, T12_L1_label_dir, transform=contrast_view_transforms)

unlabeled_other_dataset = (
    unlabeled_L0_dataset + unlabeled_L1_dataset + unlabeled_L1_L2_dataset
    + unlabeled_L2_dataset + unlabeled_L2_L3_dataset + unlabeled_L3_L4_dataset
    + unlabeled_L4_dataset + unlabeled_L4_L5_dataset + unlabeled_L5_dataset
    + unlabeled_T12_dataset + unlabeled_T12_L1_dataset
)
unlabeled_all_dataset = unlabeled_L3_dataset + unlabeled_other_dataset


# 定義simclr結構
## resnet18 + projection head (Linear + ReLU + Linear)

In [7]:
class SimCLR(pl.LightningModule):
    """ResNet-18 encoder with an MLP projection head, trained with the InfoNCE loss.

    The network's final ``fc`` layer is replaced by
    ``Linear -> ReLU -> Linear`` so that the model outputs ``hidden_dim``
    features per image. Training and validation both run ``info_nce_loss``.
    """

    def __init__(self, hidden_dim, lr, temperature, weight_decay, max_epochs=500):
        """Build the encoder and projection head.

        Args:
            hidden_dim: Size of the final feature vector; the intermediate
                layer of the head has ``4 * hidden_dim`` units.
            lr: Learning rate passed to AdamW.
            temperature: Divisor applied to the cosine similarities; must be
                greater than 0.
            weight_decay: Weight decay passed to AdamW.
            max_epochs: Used as ``T_max`` of the cosine learning-rate schedule.

        Raises:
            AssertionError: If ``temperature`` is not greater than 0.
        """
        super().__init__()  # Run the LightningModule initialisation first
        self.save_hyperparameters()  # Makes the constructor arguments available as self.hparams (read below and in the other methods)
        assert self.hparams.temperature > 0.0, 'The temperature must be a positive float!'
        # Base model f(.)
        self.convnet = torchvision.models.resnet18(num_classes=4*hidden_dim)  # Output of last linear layer
        # The MLP for g(.) consists of Linear->ReLU->Linear
        self.convnet.fc = nn.Sequential(
            self.convnet.fc,  # Linear(ResNet output, 4*hidden_dim)
            nn.ReLU(inplace=True),
            nn.Linear(4*hidden_dim, hidden_dim)
        )

    def configure_optimizers(self):
        """Return AdamW plus a cosine schedule that anneals down to ``lr / 50``."""
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr/50)
        return [optimizer], [lr_scheduler]

    def info_nce_loss(self, batch, mode='train'):
        """Compute and log the InfoNCE loss and ranking metrics for one batch.

        Args:
            batch: ``(imgs, labels)``; ``imgs`` is a sequence of tensors that
                is concatenated along dim 0. The labels are ignored.
                The positive of row ``i`` is taken to be the row half the
                concatenated batch away, which matches two equally sized
                views - assumes exactly two views; TODO confirm with the
                dataset transform.
            mode: Prefix for the logged metric names (e.g. ``'train'``,
                ``'val'``).

        Returns:
            The mean negative log-likelihood over all rows.
        """
        imgs, _ = batch  # Only the images are used; the labels are discarded
        imgs = torch.cat(imgs, dim=0)

        # Encode all images
        feats = self.convnet(imgs)
        # Calculate cosine similarity
        cos_sim = F.cosine_similarity(feats[:,None,:], feats[None,:,:], dim=-1)
        # Mask out cosine similarity to itself
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
        cos_sim.masked_fill_(self_mask, -9e15)
        # Find positive example -> batch_size//2 away from the original example
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0]//2, dims=0)
        # InfoNCE loss
        cos_sim = cos_sim / self.hparams.temperature
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = nll.mean()

        # Logging loss
        self.log(mode+'_loss', nll)
        # Get ranking position of positive example
        comb_sim = torch.cat([cos_sim[pos_mask][:,None],  # First position positive example
                              cos_sim.masked_fill(pos_mask, -9e15)],
                             dim=-1)
        # argsort gives the column order by descending similarity; argmin over
        # that order finds where column 0 (the positive) ended up, i.e. its rank.
        sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
        # Logging ranking metrics
        self.log(mode+'_acc_top1', (sim_argsort == 0).float().mean())
        self.log(mode+'_acc_top5', (sim_argsort < 5).float().mean())
        self.log(mode+'_acc_mean_pos', 1+sim_argsort.float().mean())

        return nll

    def training_step(self, batch, batch_idx):
        """Return the InfoNCE loss for the batch, logged under the ``train`` prefix."""
        return self.info_nce_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        """Log the InfoNCE loss and metrics under the ``val`` prefix; returns nothing."""
        self.info_nce_loss(batch, mode='val')

In [8]:
def train_simclr(batch_size, max_epochs=500, num_workers=4, **kwargs):
    """Train a SimCLR model, or load it if a checkpoint already exists.

    If ``<CHECKPOINT_PATH>/SimCLR.ckpt`` exists it is loaded and training is
    skipped. Otherwise the model is trained on ``unlabeled_all_dataset`` and
    validated on ``contrast_all_dataset``, and the checkpoint with the highest
    ``val_acc_top5`` is loaded and returned.

    NOTE(review): in this notebook ``unlabeled_all_dataset`` is built from the
    ".../valid" folder and ``contrast_all_dataset`` from ".../train" - confirm
    that training on the former and validating on the latter is intended.

    Args:
        batch_size: Batch size for both data loaders. The training loader
            uses ``drop_last=True``, so a dataset smaller than ``batch_size``
            yields no training batches.
        max_epochs: Number of epochs for the trainer; also forwarded to
            ``SimCLR``.
        num_workers: Worker processes per data loader (previously hard-coded
            to 4). Pass 0 to load data in the main process, e.g. when worker
            processes cannot unpickle dataset classes defined inside the
            notebook - presumably the case for spawn-based platforms such as
            Windows; verify on the target machine.
        **kwargs: Remaining ``SimCLR`` constructor arguments
            (``hidden_dim``, ``lr``, ``temperature``, ``weight_decay``).

    Returns:
        The ``SimCLR`` model restored from a checkpoint.
    """
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, 'SimCLR'),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         benchmark=True,
                         profiler='simple',
                         max_epochs=max_epochs,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode='max', monitor='val_acc_top5'),
                                    LearningRateMonitor('epoch')])
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, 'SimCLR.ckpt')
    if os.path.isfile(pretrained_filename):
        print(f'Found pretrained model at {pretrained_filename}, loading...')
        model = SimCLR.load_from_checkpoint(pretrained_filename) # Automatically loads the model with the saved hyperparameters
    else:
        train_loader = data.DataLoader(unlabeled_all_dataset, batch_size=batch_size, shuffle=True,
                                       drop_last=True, pin_memory=True, num_workers=num_workers)
        val_loader = data.DataLoader(contrast_all_dataset, batch_size=batch_size, shuffle=False,
                                     drop_last=False, pin_memory=True, num_workers=num_workers)
        pl.seed_everything(42) # To be reproducable
        model = SimCLR(max_epochs=max_epochs, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = SimCLR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    return model

In [9]:
# Launch (or reload) the SimCLR pre-training run.
# NOTE(review): info_nce_loss builds a pairwise similarity matrix over every
# image in the batch, so memory grows quadratically with batch_size, and the
# training loader drops incomplete batches - confirm 4096 fits the GPU and
# does not exceed the training set size.
simclr_model = train_simclr(
    # Data loading / schedule
    batch_size=4096,
    max_epochs=500,
    # Model
    hidden_dim=128,
    temperature=0.07,
    # Optimiser
    lr=5e-4,
    weight_decay=1e-4,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type   | Params
-----------------------------------
0 | convnet | ResNet | 11.5 M
-----------------------------------
11.5 M    Trainable params
0         Non-trainable params
11.5 M    Total params
46.019    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

e:\anaconda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


In [None]:
import os
from PIL import Image
# Preview one image from the L3 training folder to check that the path resolves.
contrast_root_dir = "E:/run_dataset_totalseg/train"
label_dir = 'L3'
image_path = os.path.join(contrast_root_dir,label_dir)
# os.listdir() returns entries in arbitrary order; sort so that index 0 is the
# same file MyData (which also sorts its listing) exposes at index 0.
img_list = sorted(os.listdir(image_path))
fullpath = os.path.join(image_path,img_list[0])
print("full path",fullpath)
img = Image.open(fullpath)
img.show()

full path E:/run_dataset_totalseg/train\L3\s0001_140_L3.png
