In [1]:
import pandas as pd
from sklearn.model_selection import KFold, train_test_split

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import torch.nn as nn 

from tqdm import tqdm

from torchvision import transforms
from PIL import Image
Image.MAX_IMAGE_PIXELS = None

In [2]:
train_labels = pd.read_csv('../data/train_labels.csv')
train_metadata = pd.read_csv('../data/train_metadata.csv')

train = train_metadata.merge(train_labels, on='filename', how='inner')

def process_age(age_str):
    age_str = age_str.replace('[', '').replace(']', '').split(':')
    return (int(age_str[0]) + int(age_str[1])) / 2

body_map = ['thigh', 'trunc', 'face', 'forearm', 'arm', 'leg', 'hand', 'foot', 'sole', 'finger', 'neck', 'toe', 'seat', 'scalp', 'nail','lower limb/hip', 'hand/foot/nail', 'head/neck', 'upper limb/shoulder', 'other']
def process_body(body):
    if body in ['thigh', 'trunc', 'face', 'forearm', 'arm', 'leg', 'hand', 'foot', 'sole', 'finger', 'neck', 'toe', 'seat', 'scalp', 'nail']:
        return body_map.index(body)
    else:
        return body_map.index('other')
   
mel_map = {'other': 0, 'YES': 1, 'NO': 2}

train.age = train.age.apply(lambda x: process_age(x)).astype(int)
train.body_site = train.body_site.replace('trunk', 'trunc') .fillna('other').apply(process_body).astype(int)
train.melanoma_history = train.melanoma_history.fillna('other').apply(lambda x: mel_map[x]).astype(int)

print(train.shape)

(1342, 14)


In [3]:
class VisioMel_Dataset(Dataset):
    
    def __init__(self, data):
        
        data = data.reset_index(drop=True)
        
        self.filenames, self.y = data['filename'], data['relapse']
        
        self.age = data['age'] / 100
        self.sex = data['sex'] - 1
        self.body_site = data['body_site']
        self.melanoma_history = data['melanoma_history']
        
        self.mean, self.std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
        self.eval_t = transforms.Compose([transforms.ToTensor(), 
                                          transforms.Normalize(mean = self.mean, std = self.std)#,
                                          #transforms.Resize((224,224))
                                        ])
        
        self.labels = torch.tensor(self.y.values, dtype = torch.float32)
        
        # To tensor labels
        self.age = torch.tensor(self.age.values, dtype = torch.float32)
        self.sex = torch.tensor(self.sex.values, dtype = torch.long)
        self.body_site = torch.tensor(self.body_site.values, dtype = torch.long)
        self.melanoma_history= torch.tensor(self.melanoma_history.values, dtype = torch.long)
        
        print(f'{self.labels.shape}')
    
    def __len__(self):
        return self.labels.shape[0]
    
    def __getitem__(self, index):
        image = Image.open('../data/images/'+self.filenames[index])
        image = self.eval_t(image)
        return image, self.age[index], self.sex[index], self.body_site[index], self.melanoma_history[index], self.labels[index]

In [4]:
#train, test = train_test_split(train, test_size=0.2, random_state=42)

train_dataset = VisioMel_Dataset(train)
#valid_dataset = VisioMel_Dataset(test)

train_dataloader = DataLoader(
                train_dataset,  
                sampler = RandomSampler(train_dataset),
                batch_size = 1
            )
'''
valid_dataloader = DataLoader(
            valid_dataset,  
            sampler = SequentialSampler(valid_dataset),
            batch_size = 1
        )
'''

torch.Size([1342])


'\nvalid_dataloader = DataLoader(\n            valid_dataset,  \n            sampler = SequentialSampler(valid_dataset),\n            batch_size = 1\n        )\n'

In [5]:
import sys;sys.path.append('../src/HIPT_4K')
from hipt_4k import HIPT_4K
from hipt_model_utils import get_vit256, get_vit4k

class VisioMel(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        
        pretrained_weights256 = '../src/HIPT_4K/Checkpoints/vit256_small_dino.pth'
        pretrained_weights4k = '../src/HIPT_4K/Checkpoints/vit4k_xs_dino.pth'
        self.device256 = torch.device('cuda:0')
        self.device4k = torch.device('cuda:0')

        ### ViT_256 + ViT_4K loaded into HIPT_4K API
        self.backbone = HIPT_4K(pretrained_weights256, pretrained_weights4k, self.device256, self.device4k)
        
        self.body = nn.Embedding(20, 11)
        self.sex = nn.Embedding(2, 3)
        self.mel = nn.Embedding(3, 5)
        
        self.act = nn.ReLU()
        
        self.cls1 = nn.Linear(212, 40)
        #self.cls2 = nn.Linear(100, 35)
        self.cls3 = nn.Linear(40, 10)
        self.cls4 = nn.Linear(10, 1)


    def forward(self, img, xage ,xsex, xbody, xmel):

        x = self.backbone(img) #.forward(x)
        x2 = self.sex(xsex)
        x3 = self.body(xbody)
        x4 = self.mel(xmel)
            
        x = torch.cat([x, xage, x2, x3, x4], dim=1)
        
        x = self.act(x)
        x = self.cls1(x)
        x = self.act(x)
        #x = self.cls2(x)
        #x = self.act(x)
        x = self.cls3(x)
        x = self.act(x)
        x = self.cls4(x)

        return x

In [6]:
device = torch.device('cuda:0')

#model = VisioMel()
#model = torch.nn.DataParallel(model, device_ids=[0, 1])
model = VisioMel().to(device)
model

Take key teacher in provided checkpoint dict
Pretrained weights found at ../src/HIPT_4K/Checkpoints/vit256_small_dino.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])
# of Patches: 196
Take key teacher in provided checkpoint dict
Pretrained weights found at ../src/HIPT_4K/Checkpoints/vit4k_xs_dino.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])


VisioMel(
  (backbone): HIPT_4K(
    (model256): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (blocks): ModuleList(
        (0-11): 12 x Block(
          (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(

In [7]:
epochs = 20000

optimizer = torch.optim.SGD(model.parameters(), lr=5e-3)
#optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
criterion = nn.BCEWithLogitsLoss()

In [8]:
ckpt = torch.load('../artifacts/ckpt_full_5.pth')
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])

In [9]:
with tqdm(total=epochs,leave=True) as pbar:
    for epoch_i in range(0, epochs):
        
        print(f'Epoch {epoch_i}')

        total_train_loss = 0
        model.train()

        optimizer.zero_grad()
        
        for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):

            b_img = batch[0].to(device) # xage ,xsex, xbody, xmel
            b_age = batch[1].unsqueeze(0).to(device)
            b_sex = batch[2].to(device)
            b_body = batch[3].to(device)
            b_mel = batch[4].to(device)
            b_labels = batch[5].to(device)

            with torch.cuda.amp.autocast(enabled=False):
                b_logits = model(b_img, b_age, b_sex, b_body, b_mel).squeeze(-1)
                    
            loss = criterion(b_logits,b_labels)
            loss.backward()

            total_train_loss += loss.item()

            #torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            optimizer.step()
            optimizer.zero_grad()
        
        avg_train_loss = (total_train_loss/len(train_dataloader))
        
        print(f'Train: loss {avg_train_loss:.5f}')
        '''
        model.eval()

        total_dev_loss = 0

        for step, batch in enumerate(valid_dataloader):

            b_img = batch[0].to(device) # xage ,xsex, xbody, xmel
            b_age = batch[1].unsqueeze(0).to(device)
            b_sex = batch[2].to(device)
            b_body = batch[3].to(device)
            b_mel = batch[4].to(device)
            b_labels = batch[5].to(device)

            with torch.cuda.amp.autocast(enabled=False):
                with torch.no_grad(): 
                    b_logits = model(b_img, b_age, b_sex, b_body, b_mel).squeeze(-1)

            loss = criterion(b_logits,b_labels)
            total_dev_loss += loss.item()
        
        avg_dev_loss = (total_dev_loss/len(valid_dataloader))
s
        print(f'Dev: loss {avg_dev_loss:.5f}')
        '''
        torch.save({
            'epoch': epoch_i,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, f'../artifacts/ckpt_full_{epoch_i}.pth')

        pbar.update()

  0%|          | 0/20000 [00:00<?, ?it/s]

Epoch 0


100%|██████████| 1342/1342 [32:11<00:00,  1.44s/it]


Train: loss 0.34846


  0%|          | 1/20000 [32:11<10728:40:50, 1931.26s/it]

Epoch 1


100%|██████████| 1342/1342 [31:15<00:00,  1.40s/it]


Train: loss 0.34206


  0%|          | 2/20000 [1:03:26<10545:50:53, 1898.44s/it]

Epoch 2


100%|██████████| 1342/1342 [31:14<00:00,  1.40s/it]


Train: loss 0.34039


  0%|          | 3/20000 [1:34:41<10486:28:52, 1887.85s/it]

Epoch 3


 88%|████████▊ | 1184/1342 [27:28<03:39,  1.39s/it]
  0%|          | 3/20000 [2:02:10<13572:45:28, 2443.46s/it]


KeyboardInterrupt: 