In [1]:
from google.colab import drive
#mount Google Drive at /content/drive; later cells read the dataset archive and the model checkpoint from it
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
cd /content/drive/MyDrive

/content/drive/MyDrive


In [3]:
############################
#Reference: <https://www.kaggle.com/code/abhinand05/vision-transformer-vit-tutorial-baseline/notebook>
############################

#install TPU dependencies
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py #transfer data to the notebook
!python pytorch-xla-env-setup.py --version 1.7 #get and setup torch_xla version
!pip install timm 

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  6034  100  6034    0     0  18915      0 --:--:-- --:--:-- --:--:-- 18915
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-1.7 ...
Found existing installation: torch 1.12.1+cu113
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting cloud-tpu-client
  Downloading cloud_tpu_client-0.10-py3-none-any.whl (7.4 kB)
Collecting google-api-python-client==1.8.0
  Downloading google_api_python_client-1.8.0-py3-none-any.whl (57 kB)
[K     |████████████████████████████████| 57 kB 3.0 MB/s 
Uninstalling torch-1.12.1+cu113:
Installing collected packages: google-api-python-client, cloud-tpu-client
  Attempting uninstall: google-api-python-client
    Found existing installation: google-api-python-client 1.12.11
    Uninstalling google-api-python-client-1.12.1

In [2]:
#import library
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

plt.style.use("ggplot")

import torch
import torch.nn as nn
import torchvision.transforms as transforms

import albumentations

import torch_xla #to connect notebook to use Cloud TPU device
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

import timm #to collect newest computer vision model

import gc #garbage collector
import os #operating system
import time
import random #generate random number
import shutil #high-level file operations (used to copy the results file)
from datetime import datetime

from PIL import Image
from tqdm.notebook import tqdm #to create progress bar
from sklearn import model_selection, metrics

# For parallelization in TPUs
# the XLA settings are handed over through environment variables of this process
for _xla_key, _xla_value in (
    ("XLA_USE_BF16", "1"),
    ("XLA_TENSOR_ALLOCATOR_MAXSIZE", "100000000"),
):
    os.environ[_xla_key] = _xla_value

# report the installed torch build
print("torchversion:", torch.__version__)



torchversion: 1.7.0a0+7e71a98


In [3]:
############################
#Coded by Ng Jiun Shen
############################

# model specific global variables
Set = dict(
    seed=3074,                          #random_state used when sampling the data splits
    model_arch='vit_base_patch16_224',  #architecture name passed to timm.create_model
    img_size=224,                       #images are resized to img_size x img_size
    epochs=10,
    train_bs=16,
    valid_bs=16,                        #batch size of the evaluation DataLoader
    lr=2e-05,
)

In [7]:
import zipfile

#unpack the competition archive stored on Google Drive into /content
with zipfile.ZipFile('/content/drive/MyDrive/UCCD3074/Asm2/cassava-leaf-disease-classification.zip', mode='r') as archive:
    archive.extractall(path='/content')

In [4]:
############################
#Coded by Ng Jiun Shen
############################

#read file
#load the image_id/label table from the extracted dataset folder
df = pd.read_csv('/content/cassava-leaf-disease-classification/train.csv')

#check success loaded
#print both ends of the table to confirm the load worked
for heading, preview in (("top 5 records\n", df.head()), ("\nlast 5 records\n", df.tail())):
    print(heading, preview)

top 5 records
          image_id  label
0  1000015157.jpg      0
1  1000201771.jpg      3
2   100042118.jpg      1
3  1000723321.jpg      1
4  1000812911.jpg      3

last 5 records
             image_id  label
21392  999068805.jpg      3
21393  999329392.jpg      3
21394  999474432.jpg      1
21395  999616605.jpg      4
21396  999998473.jpg      4


In [5]:
############################
#Coded by Leong Wai Yin
############################

#Split into train,valid,test set
#70% of the rows are sampled for training (seeded, so the split is repeatable)
df_train = df.sample(frac=0.7, random_state=Set['seed'])

#everything that was not sampled for training is shared between test and valid
in_train = df.index.isin(df_train.index)
val_test = df[~in_train]

#half of the remainder becomes the test set, the other half the valid set
df_test = val_test.sample(frac=0.5, random_state=Set['seed'])
in_test = val_test.index.isin(df_test.index)
df_valid = val_test[~in_test]

#report the size of every split
for split_name, split_frame in (
    ("dataset", df),
    ("trainset", df_train),
    ("validset", df_valid),
    ("testset", df_test),
):
    print(split_name + "'s length is", len(split_frame))

dataset's length is 21397
trainset's length is 14978
validset's length is 3209
testset's length is 3210


In [6]:
############################
#Reference: <https://www.kaggle.com/code/abhinand05/vision-transformer-vit-tutorial-baseline/notebook>
############################

class CassavaDataset(torch.utils.data.Dataset): #class for dataset
    """Map-style dataset that yields ``(image, label)`` pairs read from disk.

    Args:
        df: table whose rows are ``(image file name, label)``; it is stored
            as ``df.values`` and indexed by position.
        data_path: root folder of the extracted dataset.
        mode: ``"train"`` reads images from ``<data_path>/train_images``;
            any other value reads from ``<data_path>/test_images``.
        transforms: optional callable applied to the RGB PIL image. When it
            is ``None`` the untransformed RGB PIL image is returned.
    """

    def __init__(self, df, data_path="/content/cassava-leaf-disease-classification", mode="train", transforms=None):
        super().__init__()
        self.df_data = df.values
        self.data_path = data_path
        self.transforms = transforms
        self.mode = mode
        self.data_dir = "train_images" if mode == "train" else "test_images"

    def __len__(self):
        """Return the number of rows (samples) in the table."""
        return len(self.df_data)

    def __getitem__(self, index):
        """Load the image of row ``index`` and return it together with its label."""
        img_name, label = self.df_data[index] #each row holds (file name, label)
        img_path = os.path.join(self.data_path, self.data_dir, img_name)
        image = Image.open(img_path).convert("RGB")

        # a single variable is used for the raw and the transformed image so the
        # return value is always bound, also when no transforms were supplied
        if self.transforms is not None:
            image = self.transforms(image)

        return image, label

In [7]:
############################
#Adapted from <https://www.kaggle.com/code/abhinand05/vision-transformer-vit-tutorial-baseline/notebook>
############################

# create image augmentations
#no augmentation in valid set: only resize, tensor conversion and per-channel normalisation
valid_steps = [
    transforms.Resize((Set['img_size'], Set['img_size'])),  #square resize to the model input size
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  #per-channel (mean, std)
]
transforms_valid = transforms.Compose(valid_steps)

In [8]:
############################
#Reference: <https://www.kaggle.com/code/abhinand05/vision-transformer-vit-tutorial-baseline/notebook>
############################

class ViTBase16(nn.Module): #class for VIT module
    """Wrapper around the timm model named by ``Set['model_arch']`` with a new linear head.

    Args:
        n_classes: number of outputs of the replacement ``nn.Linear`` head.
        pretrained: when true, weights are loaded from the local checkpoint
            file below before the head is replaced.
    """

    def __init__(self, n_classes, pretrained=False):

        super(ViTBase16, self).__init__()

        # the architecture is always created without timm's own pretrained weights;
        # optional pretrained weights come from the local file instead
        self.model = timm.create_model(Set['model_arch'], pretrained=False)
        if pretrained:
            # map_location="cpu": the freshly created model lives on the CPU, so the
            # checkpoint can be read regardless of the device it was saved from
            self.model.load_state_dict(
                torch.load(
                    "../Assignment2/vit-base-models-pretrained-pytorch/jx_vit_base_p16_224-80ecf9dd.pth",
                    map_location="cpu",
                )
            )

        self.model.head = nn.Linear(self.model.head.in_features, n_classes)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)

    def inference(self, test_loader, device):
        """Run the model over ``test_loader`` and return sigmoid scores.

        Args:
            test_loader: iterable of ``(data, target)`` batches; ``target`` is ignored.
            device: device the model lives on; each batch is moved to it
                (and cast to float32 on XLA devices).

        Returns:
            NumPy array holding the sigmoid of the concatenated model outputs,
            with size-1 dimensions squeezed out.
        """
        # NOTE(review): sigmoid yields independent per-class scores that do not sum
        # to 1; if the head was trained with a softmax cross-entropy loss, softmax
        # would be the matching normalisation - confirm before interpreting the
        # values as class probabilities (the arg-max is the same either way).
        logits = []
        self.model.eval()
        with torch.no_grad():
            for data, _target in test_loader:
                # move the batch to the device that was passed in; `.to(device)`
                # targets exactly that device, whereas `.cuda()` would always pick
                # the current default CUDA device
                if device.type == "xla":
                    data = data.to(device, dtype=torch.float32)
                else:
                    data = data.to(device)

                # forward pass: compute predicted outputs by passing inputs to the model
                output = self.model(data)
                logits.append(output.detach().cpu())
        probs = torch.sigmoid(torch.cat(logits)).numpy().squeeze()

        return probs
          

In [9]:
############################
#Reference: <https://www.kaggle.com/code/abhinand05/vision-transformer-vit-tutorial-baseline/notebook>
############################
#evaluation data: the test split combined with the augmentation-free transforms
valid_dataset = CassavaDataset(df=df_test, transforms=transforms_valid)

#sampler that hands this process its share of the dataset, without shuffling
sampler_options = dict(
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=False,
)
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset, **sampler_options)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=Set['valid_bs'],
    sampler=valid_sampler,
    num_workers=8,
)

#create the 5-class model on the XLA device and restore the weights of the saved checkpoint
model = ViTBase16(n_classes=5)
device = xm.xla_device()
model.to(device)
state_dict = torch.load('/content/drive/MyDrive/model_5e_20220903-1457.pth')
model.load_state_dict(state_dict, strict=True)

#feed the batches through the XLA parallel loader and collect the scores
para_valid_loader = pl.ParallelLoader(valid_loader, [device])
per_device_batches = para_valid_loader.per_device_loader(device)
probs = model.inference(per_device_batches, device)
probs

array([[0.24364243, 0.15507847, 0.15002882, 0.6959583 , 0.9890131 ],
       [0.02843603, 0.02479816, 0.31237   , 0.9905874 , 0.32167307],
       [0.6959583 , 0.05749328, 0.19559409, 0.23231015, 0.9046505 ],
       ...,
       [0.03676946, 0.01542455, 0.15713686, 0.9879462 , 0.7879312 ],
       [0.13386749, 0.20307462, 0.02887091, 0.99444515, 0.12678517],
       [0.6522414 , 0.8056322 , 0.35577488, 0.28378138, 0.64332926]],
      dtype=float32)

In [17]:
#export the scores to CSV without the index column; the relative path is resolved against the current working directory
probs_table = pd.DataFrame(data=probs)
probs_table.to_csv('vit_probs.csv', index=False)

In [None]:
#copy the exported scores to Google Drive
#the export cell writes 'vit_probs.csv' relative to the working directory, so the
#real location is resolved here instead of assuming the file is in /content
probs_csv_path = os.path.abspath('vit_probs.csv')
drive_dir = '/content/drive/MyDrive'
try:
    shutil.copy(probs_csv_path, drive_dir)
except shutil.SameFileError:
    #the working directory already is the Drive folder, so the file is in place
    print(probs_csv_path, "is already stored on Google Drive")