<a href="https://colab.research.google.com/github/ackrds/SSLHistopathology/blob/main/main_dino_run.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tifffile
!pip install timm

import os
import tqdm
import glob
import tifffile

import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data import Subset
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as T

import matplotlib.pyplot as plt

# Mount Google Drive so the paths under /content/gdrive used below are reachable.
from google.colab import drive
drive.mount('/content/gdrive')

# Switch the working directory to the project folder. The imports of `utils`
# and `vision_transformer` come after this line -- presumably because those
# modules live in this folder (TODO confirm).
%cd /content/gdrive/MyDrive/Transformers/dino

import utils
import vision_transformer as vits
from vision_transformer import DINOHead


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Same folder as the %cd target above; later cells build the checkpoint and
# embedding file paths from it.
BASE_DIR = '/content/gdrive/MyDrive/Transformers/dino'

Collecting timm
  Downloading timm-0.9.2-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub (from timm)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors (from timm)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: safetensors, huggingface-hub, timm
Successfully installed huggingface-hub-0.16.4 safetensors-0.3.1 timm-0.9.2
Mounted at /content/gdrive
/content/gdrive/MyDrive/Transformers/dino


In [2]:
# Runs main_dino.py (resolved relative to the current working directory) as a
# shell command with the flags below. The --arch, --patch_size, --out_dim,
# --drop_path_rate and --use_bn_in_head values mirror the ones used when the
# student is rebuilt in a later cell, and --output_dir './logs/' is the folder
# the checkpoint is later loaded from (BASE_DIR + '/logs/...').
# NOTE(review): no dataset path flag is passed here, so the script presumably
# relies on its own default -- verify against main_dino.py.
!python main_dino.py --arch vit_small --patch_size 16 --out_dim 65536  --momentum_teacher 0.996 --use_bn_in_head False --warmup_teacher_temp 0.04 --teacher_temp 0.04 --warmup_teacher_temp_epochs 0 --weight_decay 0.04 --weight_decay_end 0.4 --clip_grad 3.0 --batch_size_per_gpu 64 --epochs 100 --freeze_last_layer 1 --lr 0.0005 --warmup_epochs 10 --min_lr 1e-6 --optimizer adamw --drop_path_rate 0.1 --global_crops_scale 0.4 1.0 --local_crops_number 8 --local_crops_scale 0.05 0.4  --output_dir './logs/' --saveckp_freq 20 --seed 0 --num_workers 1

Traceback (most recent call last):
  File "/content/gdrive/MyDrive/Transformers/dino/main_dino.py", line 25, in <module>
    import torch
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 1465, in <module>
    from . import _meta_registrations
  File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 7, in <module>
    from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table
  File "/usr/local/lib/python3.10/dist-packages/torch/_decomp/__init__.py", line 169, in <module>
    import torch._decomp.decompositions
  File "/usr/local/lib/python3.10/dist-packages/torch/_decomp/decompositions.py", line 10, in <module>
    import torch._prims as prims
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims/__init__.py", line 33, in <module>
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/__init__.py", line 3, in <module>
 

In [2]:
train_dir = '/content/gdrive/MyDrive/ADPathology/Data/crc-100k/NCT-CRC-HE100K'
val_dir = '/content/gdrive/MyDrive/ADPathology/Data/crc-100k/CRC-VAL-HE-7K'

# Integer labels are positions in this list. os.listdir returns entries in
# arbitrary order, so sort to keep the label <-> class mapping identical
# across sessions (embeddings are saved to disk and reused later).
classes = sorted(os.listdir(train_dir))

# Per-channel statistics shared by the training and validation pipelines.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

train_transforms = T.Compose([
    T.ToTensor(),
    # ColorJitter assumes float images in [0, 1] and clamps its output to that
    # range, so it has to run BEFORE Normalize; applied afterwards it would
    # clip the standardized values and corrupt the image.
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.Normalize(NORM_MEAN, NORM_STD),
    # Geometric augmentations stay after Normalize, as before, so the corners
    # exposed by the rotation keep being filled with 0 in normalized space.
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(15),
])

# Validation: deterministic conversion and normalization only.
val_transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(NORM_MEAN, NORM_STD),
])

class Dataset(data.Dataset):
  """Image dataset over `<data_dir>/<folder>/<CLASS>-<id>.tif` files.

  Each item is `(image, label)`: the image is read with `tifffile.imread` and
  passed through `transforms` when given; the label is the position in
  `classes` of the file-name prefix before the first '-'.

  Args:
    data_dir: root folder that is searched one level deep for `.tif` files.
    transforms: callable applied to the loaded image, or None to return the
      raw array.
    classes: list of class names; a name's index is its integer label.

  Raises:
    ValueError: from `__getitem__` when a file-name prefix is not in `classes`.
  """

  def __init__(self, data_dir, transforms, classes=classes):
    self.data_dir = data_dir
    # Sorted so that an index always refers to the same file: glob order is
    # arbitrary, and fixed index windows (Subset) are taken from this list.
    # The pattern contains no '**', so `recursive=True` had no effect.
    self.data_files = sorted(glob.glob(data_dir + '/*/*.tif'))
    self.classes = classes
    self.transforms = transforms

  def __len__(self):
    return len(self.data_files)

  def __getitem__(self, index):
    # Keep per-sample values in locals instead of on `self`, so fetching an
    # item does not mutate the dataset object.
    img_path = self.data_files[index]
    class_name = os.path.basename(img_path).split('-')[0]
    label = self.classes.index(class_name)
    img = tifffile.imread(img_path)
    if self.transforms is not None:
      img = self.transforms(img)
    return img, label

# Full training set, plus the index window [32000, 48000) of it that is
# embedded in this run.
train_dataset = Dataset(train_dir, train_transforms)
train_dataset1 = Subset(train_dataset, list(range(32000, 48000)))

# Validation set with the deterministic transform pipeline.
val_dataset = Dataset(val_dir, val_transforms)


In [3]:
# Architecture settings; they have to match the run that wrote the checkpoint.
arch = 'vit_small'
patch_size = 16
drop_path_rate = 0.1
out_dim = 65536

# Bare ViT backbone, looked up by name in the vision_transformer module.
student = vits.__dict__[arch](
    patch_size=patch_size,
    drop_path_rate=drop_path_rate,
)
embed_dim = student.embed_dim

# Wrap backbone + projection head, so the model output has `out_dim` features.
student = utils.MultiCropWrapper(
    student,
    DINOHead(
        embed_dim,
        out_dim=out_dim,
        use_bn=False,
        norm_last_layer=True,
    ),
)

ckpt_path = BASE_DIR + '/logs/checkpoint0000.pth'
state_dict = torch.load(ckpt_path, map_location='cpu')['student']
# Strip only the "module." prefix. The weights are loaded into the WRAPPED
# model, so the "backbone." prefix must stay: removing it (as before) makes
# every backbone key miss its target, and strict=False then silently leaves
# the backbone at its random initialization.
# NOTE(review): assumes utils.MultiCropWrapper registers the ViT under the
# attribute name `backbone`, as in upstream DINO -- confirm.
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
load_result = student.load_state_dict(state_dict, strict=False)

# Do not discard the load report: embeddings from a partially initialized
# network are meaningless, so fail loudly instead.
if load_result.missing_keys:
    raise RuntimeError(
        f'{len(load_result.missing_keys)} model parameters were not found in '
        f'{ckpt_path}, e.g. {load_result.missing_keys[:5]}'
    )
if load_result.unexpected_keys:
    print(f'Ignored {len(load_result.unexpected_keys)} unexpected checkpoint keys, '
          f'e.g. {load_result.unexpected_keys[:5]}')



In [4]:
def extract_embeddings(data_loader):
  """Run the global `student` over a loader and collect its outputs.

  The model is moved to `device` and put in eval mode for the pass, so
  stochastic layers such as drop-path are disabled. Gradients are turned off,
  so no autograd graph is kept per batch. The previous train/eval mode is
  restored afterwards.

  Args:
    data_loader: iterable yielding `(batch_imgs, batch_labels)` tensor pairs.

  Returns:
    A `TensorDataset` of (embeddings, labels), both on the CPU, in the order
    the loader produced them.
  """
  embeddings, labels = [], []
  student.to(device)
  was_training = student.training
  student.eval()
  try:
    with torch.no_grad():
      for batch_imgs, batch_labels in tqdm.tqdm(data_loader):
        batch_embeddings = student(batch_imgs.to(device))
        embeddings.append(batch_embeddings.cpu())
        labels.append(batch_labels)
  finally:
    # Leave the shared model in the mode the caller had it in.
    student.train(was_training)

  embeddings = torch.cat(embeddings, dim=0)
  labels = torch.cat(labels, dim=0)
  return data.TensorDataset(embeddings, labels)

# These loaders only feed a single feature-extraction pass, so keep the
# dataset order: shuffling adds nothing here and makes the row order of the
# saved embedding file differ from run to run.
train_loader1 = data.DataLoader(train_dataset1, batch_size=64, shuffle=False, drop_last=False)
val_loader = data.DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last=False)


In [None]:
# Embed the current training window and persist the resulting TensorDataset
# inside the project folder.
train_chunk_path = BASE_DIR + '/train_embeddings3.pt'
train_embeddings1 = extract_embeddings(train_loader1)
torch.save(train_embeddings1, train_chunk_path)

# Validation extraction is currently disabled.
# val_embeddings = extract_embeddings(val_loader)
# torch.save(val_embeddings, BASE_DIR + '/val_embeddings.pt')

  3%|▎         | 8/250 [06:49<3:20:07, 49.62s/it]

In [None]:
class Classifier(nn.Module):
    """Fully connected classifier with two hidden layers.

    Args:
        input_dim: number of features per sample after flattening.
        hidden_dim: width of both hidden layers.
        output_dim: number of output scores per sample.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (un-normalized) scores of shape (batch, output_dim)."""
        # Collapse everything after the batch dimension into one feature axis.
        hidden = x.view(x.shape[0], -1)
        # Both hidden layers are followed by a ReLU; the output layer is not.
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)


batch_size = 64
# NOTE(review): the original read `train_embeddings`, which is never defined
# in this notebook. The only embeddings produced above are
# `train_embeddings1`, so that is used here -- confirm that training on this
# single chunk is intended.
train_embeddings = train_embeddings1
train_loader = data.DataLoader(train_embeddings, batch_size=batch_size, shuffle=True)

input_dim = out_dim
# Was 300000: with input_dim = 65536 the first layer alone would hold about
# 2e10 weights (~79 GB in float32) and could not be allocated.
hidden_dim = 512
# One output per class. Labels are indices into `classes`, so this always
# covers every label (it was a hard-coded 10 with a leftover MNIST comment).
output_dim = len(classes)


model = Classifier(input_dim, hidden_dim, output_dim)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 1000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so the
        # epoch value is a per-sample average.
        running_loss += loss.item() * images.size(0)

    # Divide by the number of samples actually iterated. The original divided
    # by len(train_dataset), the full image dataset, which under-reports the
    # loss.
    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Train Epoch: {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

# Evaluation loop (disabled: it needs a `test_loader`, which is not built in
# this notebook). The final print is disabled together with it, because
# `accuracy` is only defined by this block.
# model.eval()  # Set model to evaluation mode
# correct = 0
# total = 0

# with torch.no_grad():
#     for images, labels in test_loader:
#         images = images.to(device)
#         labels = labels.to(device)

#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

# accuracy = correct / total
# print(f'Test Accuracy: {accuracy:.4f}')
