In [123]:
import os
import random
from sklearn.model_selection import train_test_split
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import balanced_accuracy_score
import pandas as pd
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from efficientnet_pytorch import EfficientNet
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

In [2]:
root = r'/home/woody/iwso/iwso092h/mad_ucb_kaggle'
train_images = r'train_thumbnails'
image_folder = r'/home/woody/iwso/iwso092h/mad_ucb_kaggle/train_thumbnails'
train_df = pd.read_csv(os.path.join(root,'data','train.csv'))
train_wo_tma = train_df[train_df.is_tma==False]
train_enc = pd.get_dummies(train_wo_tma, columns=['label'])
train_enc[['label_CC',
       'label_EC', 'label_HGSC', 'label_LGSC', 'label_MC']] = train_enc[['label_CC',
       'label_EC', 'label_HGSC', 'label_LGSC', 'label_MC']].astype(int)
train_split, val_split = train_test_split(train_enc, test_size=0.2, random_state=17)
train_split = train_split.reset_index(drop=True)
val_split = val_split.reset_index(drop=True)

In [59]:
def get_tiles(img, tile_size=256, n_tiles=30, mode=0):
    h, w, c = img.shape
    pad_h = (tile_size - h % tile_size) % tile_size + ((tile_size * mode) // 2)
    pad_w = (tile_size - w % tile_size) % tile_size + ((tile_size * mode) // 2)

    img = np.pad(
        img,
        [[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2], [0, 0]],
        constant_values=0,
    )
    img = img.reshape(
        img.shape[0] // tile_size, tile_size, img.shape[1] // tile_size, tile_size, 3
    )
    img = img.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)
    
    idxs = np.argsort(img.reshape(img.shape[0], -1).sum(-1))
    if len(img) < n_tiles:
        img = np.pad(
            img, [[0, n_tiles - len(img)], [0,0], [0,0], [0,0]], constant_values=255
        )
    # idxs = np.argsort(-img.reshape(img.shape[0], -1).sum(-1))[:n_tiles]
    # print(type(idxs))
    if idxs.shape[0]>n_tiles:
        idxs = idxs[-n_tiles:]
    img = img[idxs]
    
    return img

def concat_tiles(tiles, n_tiles, image_size):
    idxes = list(range(n_tiles))
    
    n_row_tiles = int(np.sqrt(n_tiles))
    img = np.zeros(
        (image_size*n_row_tiles, image_size*n_row_tiles, 3), dtype="uint8"
    )
    
    for h in range(n_row_tiles):
        for w in range(n_row_tiles):
            i = h * n_row_tiles + w
            if len(tiles) > idxes[i]:
                this_img = tiles[idxes[i]]
            else:
                this_img = np.ones((image_size, image_size, 3), dtype="uint8") * 255
                
            h1 = h * image_size
            w1 = w * image_size
            img[h1 : h1 + image_size, w1 : w1 + image_size] = this_img
    return img

def sort_tiles_by_intensity(tiles):
    intensities = np.mean(tiles, axis=(1, 2, 3))  # Calculate mean intensity for each tile
    
    sorted_indices = np.argsort(-intensities)
    
    sorted_tiles = tiles[sorted_indices]
    sorted_intensities = intensities[sorted_indices]
    
    return sorted_tiles, sorted_intensities, sorted_indices

In [4]:
import albumentations as A
import numpy as np
import torch

def to_tensor(x):
    x = x.astype("float32") / 255

    return torch.from_numpy(x).permute(2, 0, 1)

In [114]:
class CustomCancerDataset(Dataset):
    def __init__(self, metadata_df, image_folder, transform=None):
        self.metadata_df = metadata_df
        self.image_folder = image_folder
        self.transform = transform  # Use the provided transform

    def __len__(self):
        return len(self.metadata_df)
    
    def __getitem__(self, idx):
        image_ids = self.metadata_df.image_id[idx]  
        image_name = os.path.join(self.image_folder, "{}_thumbnail.png".format(image_ids))
        image = Image.open(image_name)
        img = get_tiles(
            np.array(image),
            mode=0, n_tiles= 64
        )
        
        sorted_img = sort_tiles_by_intensity(img)
        img = concat_tiles(
            img, 64, 256
        )

        # img = to_tensor(img)
        img = Image.fromarray(img)
        
        if self.transform:
            img = self.transform(img)
        
        label_CC  = self.metadata_df.label_CC[idx].astype(int)  
        label_EC  = self.metadata_df.label_EC[idx].astype(int)    
        label_HGSC  = self.metadata_df.label_HGSC[idx].astype(int)    
        label_LGSC  = self.metadata_df.label_LGSC[idx].astype(int)    
        label_MC  = self.metadata_df.label_MC[idx].astype(int)

        return img, torch.tensor([label_CC, label_EC, label_HGSC, label_LGSC, label_MC], dtype=torch.float)

In [115]:
data_transforms = transforms.Compose([
    
    transforms.Resize(224),
    transforms.ToTensor(),
])

train_ucb_dataset = CustomCancerDataset(train_split, image_folder='/home/woody/iwso/iwso092h/mad_ucb_kaggle/train_thumbnails',
                                        transform=data_transforms)
val_ucb_dataset = CustomCancerDataset(val_split, image_folder='/home/woody/iwso/iwso092h/mad_ucb_kaggle/train_thumbnails',
                                      transform=data_transforms)

batch_size = 32
train_loader = DataLoader(train_ucb_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ucb_dataset, batch_size=batch_size)

In [78]:
def train(net, loss, train_dataloader, valid_dataloader, device, batch_size, num_epoch, lr, lr_min, optim='sgd', init=True, scheduler_type='Cosine'):
    def init_xavier(m): 
        #if type(m) == nn.Linear or type(m) == nn.Conv2d:
        if type(m) == nn.Linear:
            nn.init.xavier_normal_(m.weight)

    if init:
        net.apply(init_xavier)

    print('training on:', device)
    net.to(device)
    
    if optim == 'sgd': 
        optimizer = torch.optim.SGD((param for param in net.parameters() if param.requires_grad), lr=lr,
                                    weight_decay=0)
    elif optim == 'adam':
        optimizer = torch.optim.Adam((param for param in net.parameters() if param.requires_grad), lr=lr,
                                     weight_decay=0)
    elif optim == 'adamW':
        optimizer = torch.optim.AdamW((param for param in net.parameters() if param.requires_grad), lr=lr,
                                      weight_decay=0)
    elif optim == 'ranger':
        optimizer = Ranger((param for param in net.parameters() if param.requires_grad), lr=lr,
                           weight_decay=0)
    if scheduler_type == 'Cosine':
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epoch, eta_min=lr_min)
    
    train_losses = []
    train_acces = []
    eval_acces = []
    best_acc = 0.0
    #Train
    for epoch in range(num_epoch):

        print("——————Start of training round {}——————".format(epoch + 1))

        
        net.train()
        train_acc = 0
        for batch in tqdm(train_dataloader, desc='Train'):
            imgs, targets = batch
            imgs = imgs.to(device)
            #targets = torch.cat(targets, dim = 0)
            targets = targets.to(device)
            output = net(imgs)

            Loss = loss(output, targets)
          
            optimizer.zero_grad()
            Loss.backward()
            optimizer.step()

            _, pred = output.max(1)
            num_correct = (pred == targets).sum().item()
            acc = num_correct / (batch_size)
            train_acc += acc
        scheduler.step()
        print("epoch: {}, Loss: {}, Acc: {}".format(epoch, Loss.item(), train_acc / len(train_dataloader)))
        train_acces.append(train_acc / len(train_dataloader))
        train_losses.append(Loss.item())

        
        net.eval()
        eval_loss = 0
        eval_acc = 0
        with torch.no_grad():
            for imgs, targets in valid_dataloader:
                imgs = imgs.to(device)
                #targets = torch.cat(targets, dim = 0)
                targets = targets.to(device)
                output = net(imgs)
                Loss = loss(output, targets)
                _, pred = output.max(1)
                num_correct = (pred == targets).sum().item()
                eval_loss += Loss
                acc = num_correct / imgs.shape[0]
                eval_acc += acc

            eval_losses = eval_loss / (len(valid_dataloader))
            eval_acc = eval_acc / (len(valid_dataloader))
            if eval_acc > best_acc:
                best_acc = eval_acc
                torch.save(net.state_dict(),'best_acc.pth')
            eval_acces.append(eval_acc)
            print("Loss on the overall validation set: {}".format(eval_losses))
            print("Correctness on the overall validation set: {}".format(eval_acc))
    return train_losses, train_acces, eval_acces

In [118]:
def train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, num_epochs=10, save_path='model'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc='training:', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.float())  # Convert labels to float
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            all_preds = []
            all_labels = []
            for inputs, labels in tqdm(val_loader, desc='validating', leave=False):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)  # Ensure that your loss function matches the task
                val_loss += loss.item()

                # Calculate predictions
                _, preds = torch.max(outputs, 1)  # Get the predicted class
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

            # Calculate accuracy or other relevant metrics
            # val_accuracy = accuracy_score(all_labels, all_preds)
                    # val_balanced_accuracy = accuracy_score(all_labels, all_preds)

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {running_loss/len(train_loader):.4f} "
              f"Val Loss: {val_loss/len(val_loader):.4f} "
            #   f"Val Balanced Accuracy: {val_accuracy:.4f}"
                )

        scheduler.step()
        
        if (epoch + 1) % 3 == 0:
            torch.save(model.state_dict(), f'{save_path}epoch_{epoch+1}_mobile2.pth')


In [87]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = nn.CrossEntropyLoss()

In [71]:
from transformers import AutoImageProcessor, ViTForImageClassification

2023-12-02 00:43:55.361113: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-02 00:43:59.092694: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [120]:
import pytorch_lightning as pl
import timm

class ViT(pl.LightningModule):
    def __init__(self, num_classes, image_size=224, backbone='vit_base_patch16_224', finetune_layer=True):
        super().__init__()
        self.model = timm.create_model(backbone, pretrained=True)
        
        # Freeze all layers except the last one for transfer learning
        if finetune_layer:
            for param in self.model.parameters():
                param.requires_grad = False
            self.model.head = nn.Identity()
            self.finetune_layer = nn.Linear(self.model.head.in_features, num_classes)
        else:
            self.model.head = nn.Linear(self.model.head.in_features, num_classes)
            self.finetune_layer = None

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        
        if self.finetune_layer is not None:
            logits = self.finetune_layer(logits)

        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)

        if self.finetune_layer is not None:
            logits = self.finetune_layer(logits)

        loss = F.cross_entropy(logits, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        if self.finetune_layer is not None:
            parameters = list(self.model.parameters()) + list(self.finetune_layer.parameters())
        else:
            parameters = self.parameters()

        optimizer = torch.optim.Adam(parameters, lr=1e-4)
        return optimizer


In [124]:
model = ViT(num_classes=5, finetune_layer=True)

trainer = pl.Trainer(max_epochs=10)  # Adjust max_epochs and gpus as needed
trainer.fit(model, train_loader, val_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type              | Params
--------------------------------------------
0 | model | VisionTransformer | 85.8 M
--------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.210   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/woody/iwso/iwso092h/miniconda/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/home/woody/iwso/iwso092h/miniconda/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (13) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/home/woody/iwso/iwso092h/miniconda/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
