In [19]:
import argparse
import json
import pathlib

import pandas as pd
import timm
import torch
import torchvision.transforms as transforms
import tqdm
from torch.utils.data import DataLoader, SubsetRandomSampler, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder

from sklearn.preprocessing import LabelEncoder

from DINO.evaluation import compute_embedding, compute_knn
from DINO.utils import DataAugmentation, Head, Loss, MultiCropWrapper, clip_gradients

import cv2

In [4]:
class args:
    batch_size = 4
    device = "cuda"
    logging_freq = 200
    momentum_teacher = 0.9995
    n_crops = 4
    n_epochs = 100
    out_dim = 1024
    tensorboard_dir = "logs"
    clip_grad = 2.0
    norm_last_layer = True
    batch_size_eval = 64
    teacher_temp = 0.04
    student_temp = 0.1
    pretrained = True
    weight_decay = 0.4

args = args

In [5]:
vit_name, dim = "vit_deit_small_patch16_224", 384
path_dataset_train = pathlib.Path("data/imagenette2-320/train")
path_dataset_val = pathlib.Path("data/imagenette2-320/val")
path_labels = pathlib.Path("data/imagenette_labels.json")

logging_path = pathlib.Path(args.tensorboard_dir)
device = torch.device(args.device)

n_workers = os.cpu_count()

In [6]:
data_path = "/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/"
tmp = [data_path + o for o in os.listdir(data_path)]
tmp

['/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/test_path_list.pkl',
 '/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/val_path_list.pkl',
 '/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/test_label_list.pkl',
 '/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/train_label_list.pkl',
 '/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/val_label_list.pkl',
 '/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/train_path_list.pkl']

In [7]:
train_path_list = pd.read_pickle('/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/train_path_list.pkl')
val_path_list = pd.read_pickle('/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/val_path_list.pkl')
test_path_list = pd.read_pickle('/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/test_path_list.pkl')

train_label_list = pd.read_pickle('/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/train_label_list.pkl')
val_label_list = pd.read_pickle('/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/val_label_list.pkl')
test_label_list = pd.read_pickle('/home/tadokororyu/Research/Contrastive_learning/code/pkl_data/baseline_cleaned/test_label_list.pkl')

In [8]:
print(len(train_path_list),len(val_path_list),len(test_path_list))

47822 7141 12638


In [25]:
train_path_list

['/home/tadokororyu/NAS_alma/ClinicalImages/Shido/NSDD_59/NG/basal cell carcinoma (rodent ulcer_basiloma)/7013_2011_02_01_TB4_0.jpg',
 '/home/tadokororyu/NAS_alma/ClinicalImages/Shido/NSDD_59/NG/basal cell carcinoma (rodent ulcer_basiloma)/22306_2018_04_26_TB1_0.jpg',
 '/home/tadokororyu/NAS_alma/ClinicalImages/Shido/NSDD_59/NG/basal cell carcinoma (rodent ulcer_basiloma)/2824_2007_09_26_TB4_0.jpg',
 '/home/tadokororyu/NAS_alma/ClinicalImages/Shido/NSDD_59/NG/basal cell carcinoma (rodent ulcer_basiloma)/7362_2011_05_10_TB1_0.jpg',
 '/home/tadokororyu/NAS_alma/ClinicalImages/Shido/NSDD_59/NG/basal cell carcinoma (rodent ulcer_basiloma)/7562_2011_07_01_TB1_0.jpg',
 '/home/tadokororyu/NAS_alma/ClinicalImages/Shido/NSDD_59/NG/basal cell carcinoma (rodent ulcer_basiloma)/1376_2010_12_24_KO2_0.jpg',
 '/home/tadokororyu/NAS_alma/ClinicalImages/Shido/NSDD_59/NG/basal cell carcinoma (rodent ulcer_basiloma)/17265_2016_11_01_TB4_0.jpg',
 '/home/tadokororyu/NAS_alma/ClinicalImages/Shido/NSDD_59/NG

In [9]:
le = LabelEncoder()
le.fit(train_label_list)

train_labels=le.transform(train_label_list)
val_labels=le.transform(val_label_list)
test_labels=le.transform(test_label_list)

In [10]:
train_labels

array([ 5,  5,  5, ..., 25, 25, 25])

In [21]:
class ImageFolder(Dataset):
    def __init__(self, file_list, label_list , transform=None):
        self.file_list = file_list
        self.label_list = label_list
        self.transforms = transforms
        
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, index):
        img_path = self.file_list[index]
        img = cv2.imread(img_path)
        
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = self.label_list[index]
        
        if self.transforms:
            img = self.transforms(img)
            
        return img, label

In [22]:
transform_aug = DataAugmentation(size=224, n_local_crops=args.n_crops - 2)
transform_plain = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.Resize((224, 224)),
    ]
)

dataset_train_aug = ImageFolder(train_path_list, train_labels, transform=transform_aug)
dataset_train_plain = ImageFolder(train_path_list, train_labels, transform=transform_plain)
dataset_val_plain = ImageFolder(val_path_list, val_labels, transform=transform_plain)


data_loader_train_aug = DataLoader(
    dataset_train_aug,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=n_workers,
    pin_memory=True,
)
data_loader_train_plain = DataLoader(
    dataset_train_plain,
    batch_size=args.batch_size_eval,
    drop_last=False,
    num_workers=n_workers,
)
data_loader_val_plain = DataLoader(
    dataset_val_plain,
    batch_size=args.batch_size_eval,
    drop_last=False,
    num_workers=n_workers,
)
data_loader_val_plain_subset = DataLoader(
    dataset_val_plain,
    batch_size=args.batch_size_eval,
    drop_last=False,
    sampler=SubsetRandomSampler(list(range(0, len(dataset_val_plain), 50))),
    num_workers=n_workers,
)




In [23]:
student_vit = timm.create_model(vit_name, pretrained=args.pretrained)
teacher_vit = timm.create_model(vit_name, pretrained=args.pretrained)

student = MultiCropWrapper(
    student_vit,
    Head(
        dim,
        args.out_dim,
        norm_last_layer=args.norm_last_layer,
    ),
)
teacher = MultiCropWrapper(teacher_vit, Head(dim, args.out_dim))
student, teacher = student.to(device), teacher.to(device)

teacher.load_state_dict(student.state_dict())

for p in teacher.parameters():
    p.requires_grad = False

In [17]:
# Loss related
loss_inst = Loss(
    args.out_dim,
    teacher_temp=args.teacher_temp,
    student_temp=args.student_temp,
).to(device)
lr = 0.0005 * args.batch_size / 256
optimizer = torch.optim.AdamW(
    student.parameters(),
    lr=lr,
    weight_decay=args.weight_decay,
)

In [24]:
# Training loop
n_batches = len(dataset_train_aug) // args.batch_size
best_acc = 0
n_steps = 0

for e in range(args.n_epochs):
    for i, (images, _) in tqdm.tqdm(
        enumerate(data_loader_train_aug), total=n_batches
    ):
        if n_steps % args.logging_freq == 0:
            student.eval()

            # Embedding
            embs, imgs, labels_ = compute_embedding(
                student.backbone,
                data_loader_val_plain_subset,
            )
            writer.add_embedding(
                embs,
                metadata=[label_mapping[l] for l in labels_],
                label_img=imgs,
                global_step=n_steps,
                tag="embeddings",
            )

            # KNN
            current_acc = compute_knn(
                student.backbone,
                data_loader_train_plain,
                data_loader_val_plain,
            )
            writer.add_scalar("knn-accuracy", current_acc, n_steps)
            if current_acc > best_acc:
                torch.save(student, logging_path / "best_model.pth")
                best_acc = current_acc

            student.train()

        images = [img.to(device) for img in images]

        teacher_output = teacher(images[:2])
        student_output = student(images)

        loss = loss_inst(student_output, teacher_output)

        optimizer.zero_grad()
        loss.backward()
        clip_gradients(student, args.clip_grad)
        optimizer.step()

        with torch.no_grad():
            for student_ps, teacher_ps in zip(
                student.parameters(), teacher.parameters()
            ):
                teacher_ps.data.mul_(args.momentum_teacher)
                teacher_ps.data.add_(
                    (1 - args.momentum_teacher) * student_ps.detach().data
                )

        writer.add_scalar("train_loss", loss, n_steps)

        n_steps += 1

  0%|          | 0/11955 [00:01<?, ?it/s]


TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/tadokororyu/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/tadokororyu/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/tadokororyu/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-21-e24fedcc583d>", line 18, in __getitem__
    img = self.transforms(img)
TypeError: 'module' object is not callable
