# Load image data

In [5]:
import pandas as pd
import numpy as np

from PIL import Image
from transformers import AutoImageProcessor, ViTModel

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import BertTokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(13)

class_ind = {'CC':0, 'EC':1, 'LGSC':2, 'HGSC':3, 'MC':4}

class OvarianImages(Dataset):
    def __init__(self, annotations_file, img_dir):
        self.img_metadata = pd.read_csv(annotations_file)
        self.img_dir = img_dir

    def __len__(self):
        return len(self.img_metadata)

    def __getitem__(self, idx):
        sample_idx = idx % len(self.img_metadata) 
        sample = self.img_metadata.iloc[sample_idx, 0]

        label = torch.tensor(class_ind[self.img_metadata.iloc[sample_idx, 1]]).to(device)  

        image_patches = []
        for i in range(100):
            img_path = self.img_dir + f'sample_{sample}' + f'/{sample}_{i}.png'
            patch = torch.tensor(np.asarray(Image.open(img_path))[13:237, 13:237].T, dtype=torch.float32)
            image_patches.append(patch)
        image_patches = torch.stack(image_patches, dim=0).to(device)
        return image_patches, label

In [3]:
# Create Dataset
metadata = "/scratch1/yuqiuwan/CSCI567/train.csv"
image_dir = "/scratch1/yuqiuwan/CSCI567/preprocess_images/"
imgs = OvarianImages(metadata, image_dir)
train_set, test_set = torch.utils.data.random_split(imgs, [0.8, 0.2])

# Set up Base model architecture

In [4]:
# Base model
class BaselineModel(nn.Module):
    def __init__(self, FeatureExtractor, d_embed=768, n_classes=5):
        super().__init__()
        self.image_encoder = FeatureExtractor
        self.image_proj = nn.Linear(d_embed, n_classes)
        self.n_classes = n_classes

    def forward(self, *args, **kwargs):
        outputs = self.image_encoder(*args, **kwargs)
        features = outputs.last_hidden_state[:, 0, :] 

        logits = self.image_proj(features).view(-1, 100, self.n_classes)
        logits = torch.sum(logits, axis=1)

        return logits

# Set up training procedure

In [64]:
def train_loop(dataloader, model, loss_fn, optimizer, image_processor=None):
    size = len(dataloader.dataset)
    model.train()
    for batch, (imgs, labels) in enumerate(dataloader): 
        imgs = imgs.view(-1, 3, 224, 224)
        if image_processor:
            imgs = image_processor(imgs, return_tensors="pt").to(device)
            # Compute prediction and loss
            logits = model(imgs['pixel_values'])
        else:
            # Compute prediction and loss
            logits = model(imgs)

        loss = loss_fn(logits, labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc = torch.sum(torch.argmax(logits, axis=1) == labels) / logits.shape[0]

        if batch % 1 == 0:
            loss, current = loss.item(), (batch + 1) * batch_size
            print(f"loss: {loss:>7f}; train_acc: {train_acc:>7f}  [{current:>5d}/{size:>5d}]")

######################## Train Base Model ########################
batch_size = 1
dropout = 0.0
learning_rate = 10**(-3)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

# load phikon
image_processor = AutoImageProcessor.from_pretrained("owkin/phikon")
extractor = ViTModel.from_pretrained("owkin/phikon", add_pooling_layer=False)
extractor.eval()

# Initialize base model
base_model = BaselineModel(extractor, d_embed=768, n_classes=5).to(device)
# Freeze the pre-trained model
for p in base_model.image_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(base_model.parameters(), lr=learning_rate)

Some weights of the model checkpoint at owkin/phikon were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
epochs = 2
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, base_model, F.cross_entropy, optimizer, image_processor=image_processor)
print("Done!")

In [49]:
torch.save(base_model.state_dict(), '/scratch1/yuqiuwan/CSCI567/phikon_model_state_dict.pt')

# Test model

In [None]:
image_processor = AutoImageProcessor.from_pretrained("owkin/phikon")
extractor = ViTModel.from_pretrained("owkin/phikon", add_pooling_layer=False)
extractor.eval()

# Initialize base model
base_model = BaselineModel(extractor, d_embed=768, n_classes=5)
base_model.load_state_dict(torch.load('/scratch1/yuqiuwan/CSCI567/phikon_model_state_dict2.pt'))
base_model.eval()
base_model.to(device)

In [9]:
test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

def test_loop(dataloader, model, image_processor=None):
    size = len(dataloader.dataset)
    model.eval()
    correct_num = 0
    for batch, (imgs, labels) in enumerate(dataloader): 
        imgs = imgs.view(-1, 3, 224, 224)
        if image_processor:
            imgs = image_processor(imgs, return_tensors="pt").to(device)
            logits = model(**imgs)
        else:
            logits = model(imgs)

        correct_num += torch.sum(torch.argmax(logits, axis=1) == labels) 
        print(batch, 'correct_num:', correct_num)
    
    test_acc = correct_num  / size
    print('Test_accuracy:', test_acc)

        
test_loop(test_dataloader, base_model, image_processor)

0 correct_num: tensor(1, device='cuda:0')
1 correct_num: tensor(2, device='cuda:0')
2 correct_num: tensor(3, device='cuda:0')
3 correct_num: tensor(4, device='cuda:0')
4 correct_num: tensor(5, device='cuda:0')
5 correct_num: tensor(6, device='cuda:0')
6 correct_num: tensor(7, device='cuda:0')
7 correct_num: tensor(8, device='cuda:0')
8 correct_num: tensor(9, device='cuda:0')
9 correct_num: tensor(10, device='cuda:0')
10 correct_num: tensor(11, device='cuda:0')
11 correct_num: tensor(12, device='cuda:0')
12 correct_num: tensor(13, device='cuda:0')
13 correct_num: tensor(13, device='cuda:0')
14 correct_num: tensor(14, device='cuda:0')
15 correct_num: tensor(15, device='cuda:0')
16 correct_num: tensor(16, device='cuda:0')
17 correct_num: tensor(17, device='cuda:0')
18 correct_num: tensor(18, device='cuda:0')
19 correct_num: tensor(19, device='cuda:0')
20 correct_num: tensor(20, device='cuda:0')
21 correct_num: tensor(21, device='cuda:0')
22 correct_num: tensor(22, device='cuda:0')
23 corr