In [1]:
# download dataset
!gdown --id 1FoAEY_u0PTAlrscjEifi2om15A83wL78

# unzip dataset
!unzip -q EndoVis-18-VQA.zip

Downloading...
From (original): https://drive.google.com/uc?id=1FoAEY_u0PTAlrscjEifi2om15A83wL78
From (redirected): https://drive.google.com/uc?id=1FoAEY_u0PTAlrscjEifi2om15A83wL78&confirm=t&uuid=60f52ea8-bdc1-489d-aec1-4c18b2d3c516
To: /content/EndoVis-18-VQA.zip
100% 2.71G/2.71G [00:25<00:00, 106MB/s]


In [2]:
# install libs
!pip install -q timm==0.9.12 fairscale==0.4.13 scikit-learn==1.3.2

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/60.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/266.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.3/266.3 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m49.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m37.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for fairscale (pyproject.toml) .

In [3]:
# dataloader
import os
import glob

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from pathlib import Path
from torchvision.transforms.functional import InterpolationMode

# utils
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_fscore_support

# main
import torch
import argparse
import torch.utils.data
import torch.nn.functional as F
import numpy as np
import random

from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# model
from transformers import GPT2Tokenizer, GPT2Model, ViTModel
from transformers import BertModel

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

### Dataloader

In [4]:
class EndoVis18VQAGPTClassification(Dataset):
    def __init__(self, seq, folder_head, folder_tail, transform=None):
        # define transform
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform

        # get files, questions and answers
        filenames = []
        for curr_seq in seq:
            filenames = filenames + glob.glob(folder_head + str(curr_seq) + folder_tail)
        self.vqas = []
        for file in filenames:
            file_data = open(file, "r")
            lines = [line.strip("\n") for line in file_data if line != "\n"]
            file_data.close()
            for line in lines:
                self.vqas.append([file, line])
        print('Total files: %d | Total question: %.d' % (len(filenames), len(self.vqas)))

        # Labels
        self.labels = ['kidney', 'Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation',
                       'Tool_Manipulation', 'Cutting', 'Cauterization', 'Suction',
                       'Looping', 'Suturing', 'Clipping', 'Staple', 'Ultrasound_Sensing',
                       'left-top', 'right-top', 'left-bottom', 'right-bottom']

    def __len__(self):
        return len(self.vqas)

    def __getitem__(self, idx):
        # get path
        qa_full_path = Path(self.vqas[idx][0])
        seq_path = qa_full_path.parents[2]
        file_name = self.vqas[idx][0].split('/')[-1]

        # img
        img_loc = os.path.join(seq_path, 'left_fr', file_name.split('_')[0] + '.png')
        raw_image = Image.open(img_loc).convert('RGB')
        img = self.transform(raw_image)

        # question and answer
        question = self.vqas[idx][1].split('|')[0]
        answer = self.vqas[idx][1].split('|')[1]
        label = self.labels.index(str(answer))

        return img_loc, img, question, label

### Utils

In [5]:
def save_clf_checkpoint(checkpoint_dir, epoch, epochs_since_improvement, model, optimizer, Acc, final_args):
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'Acc': Acc,
             'model': model,
             'optimizer': optimizer,
             'final_args': final_args}
    filename = checkpoint_dir + 'Best.pth.tar.gz'
    torch.save(state, filename)

def adjust_learning_rate(optimizer, shrink_factor):
    print("\nDECAYING learning rate.")
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor
    print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))

def calc_acc(y_true, y_pred):
    acc = accuracy_score(y_true, y_pred)
    return acc


def calc_classwise_acc(y_true, y_pred):
    matrix = confusion_matrix(y_true, y_pred)
    classwise_acc = matrix.diagonal()/matrix.sum(axis=1)
    return classwise_acc

def calc_precision_recall_fscore(y_true, y_pred):
    precision, recall, fscore, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=1)
    return (precision, recall, fscore)

### Model

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class PitVQANet(nn.Module):
    def __init__(self, num_class=18):  # 18/59
        super().__init__()

        # visual encoder
        model_name = "google/vit-base-patch16-224-in21k"
        self.visual_encoder = ViTModel.from_pretrained(model_name)

        # tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token  # end of string

        # text encoder
        self.text_encoder = BertModel.from_pretrained("bert-base-uncased")
        new_vocab_size = len(self.tokenizer)
        old_embeddings = self.text_encoder.embeddings.word_embeddings
        new_embeddings = nn.Embedding(new_vocab_size, old_embeddings.embedding_dim)
        new_embeddings.weight.data[:old_embeddings.num_embeddings, :] = old_embeddings.weight.data
        self.text_encoder.embeddings.word_embeddings = new_embeddings

        # text decoder
        self.gpt_decoder = GPT2Model.from_pretrained('gpt2')

        # intermediate layers
        self.intermediate_layer = nn.Linear(768, 512)
        self.se_layer = nn.Sequential(
            nn.Linear(512, 512),
            nn.Sigmoid()
        )
        self.LayerNorm = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(0.2)

        # classifier
        self.classifier = nn.Linear(512, num_class)

    def forward(self, image, question):
        image = image.to(device)

        # visual encoder
        image_embeds = self.visual_encoder(image).last_hidden_state
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        # tokenization
        encoder_question = self.tokenizer(question, return_tensors="pt", truncation=True,
                                          padding='max_length', max_length=25).to(image.device)

        # text encoder
        text_embeds = self.text_encoder(
            input_ids=encoder_question.input_ids,
            attention_mask=encoder_question.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True
        ).last_hidden_state

        # text decoder
        gpt_output = self.gpt_decoder(inputs_embeds=text_embeds,
                                      encoder_attention_mask=encoder_question.attention_mask)
        decoder_output = gpt_output.last_hidden_state

        # average pool
        decoder_output = decoder_output.swapaxes(1, 2)
        decoder_output = F.adaptive_avg_pool1d(decoder_output, 1)
        decoder_output = decoder_output.swapaxes(1, 2).squeeze(1)

        out = self.intermediate_layer(decoder_output)
        out = torch.mul(out, self.se_layer(out))
        out = self.LayerNorm(out)
        out = self.dropout(out)

        # classification layer
        out = self.classifier(out)
        return out

### main

In [7]:
class InitParameter:
    def __init__(self):
        self.epochs = 60
        self.batch_size = 16
        self.workers = 4
        self.random_seed = 21
        self.lr = 0.00001
        self.question_len = 32
        self.num_class = 18  # 18/59

def seed_everything(seed=3407):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

def train(train_dataloader, model, criterion, optimizer, epoch, device):
    model.train()
    total_loss = 0.0
    label_true = None
    label_pred = None
    label_score = None

    for i, (_, images, questions, labels) in enumerate(tqdm(train_dataloader), 0):
        # labels
        labels = labels.to(device)
        outputs = model(image=images.to(device), question=questions)  # questions is a tuple
        loss = criterion(outputs, labels)  # calculate loss
        optimizer.zero_grad()
        loss.backward()  # calculate gradient
        optimizer.step()  # update parameters

        # print statistics
        total_loss += loss.item()

        scores, predicted = torch.max(F.softmax(outputs, dim=1).data, 1)
        if label_true is None:  # accumulate true labels of the entire training set
            label_true = labels.data.cpu()
        else:
            label_true = torch.cat((label_true, labels.data.cpu()), 0)
        if label_pred is None:  # accumulate pred labels of the entire training set
            label_pred = predicted.data.cpu()
        else:
            label_pred = torch.cat((label_pred, predicted.data.cpu()), 0)
        if label_score is None:
            label_score = scores.data.cpu()
        else:
            label_score = torch.cat((label_score, scores.data.cpu()), 0)

    # loss and acc
    acc, c_acc = calc_acc(label_true, label_pred), calc_classwise_acc(label_true, label_pred)
    precision, recall, f_score = calc_precision_recall_fscore(label_true, label_pred)
    print(f'Train: epoch: {epoch} loss: {total_loss} | Acc: {acc} | '
          f'Precision: {precision} | Recall: {recall} | F1 Score: {f_score}')
    return acc

def validate(val_loader, model, criterion, epoch, device):
    model.eval()
    total_loss = 0.0
    label_true = None
    label_pred = None
    label_score = None
    file_names = list()

    with torch.no_grad():
        for i, (file_name, images, questions, labels) in enumerate(tqdm(val_loader), 0):
            # label
            labels = labels.to(device)

            # model forward pass
            outputs = model(image=images.to(device), question=questions)

            # loss
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            scores, predicted = torch.max(F.softmax(outputs, dim=1).data, 1)
            label_true = labels.data.cpu() if label_true is None else torch.cat((label_true, labels.data.cpu()), 0)
            label_pred = predicted.data.cpu() if label_pred is None else torch.cat((label_pred, predicted.data.cpu()), 0)
            label_score = scores.data.cpu() if label_score is None else torch.cat((label_score, scores.data.cpu()), 0)
            for f in file_name:
                file_names.append(f)  # not used

    acc = calc_acc(label_true, label_pred)
    c_acc = 0.0
    precision, recall, f_score = calc_precision_recall_fscore(label_true, label_pred)
    print(f'Test: epoch: {epoch} test loss: {total_loss} | test acc: {acc} | '
          f'test precision: {precision} | test recall: {recall} | test F1: {f_score}')
    return acc, c_acc, precision, recall, f_score


if __name__ == '__main__':
    # init parameters
    args = InitParameter()

    seed_everything(args.random_seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    start_epoch = 1
    best_epoch = [0]
    best_results = [0.0]
    epochs_since_improvement = 0

    # data location
    train_seq = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
    val_seq = [1, 5, 16]
    folder_head = '/content/EndoVis-18-VQA/seq_'
    folder_tail = '/vqa/Classification/*.txt'

    # dataloader
    train_dataset = EndoVis18VQAGPTClassification(train_seq, folder_head, folder_tail)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    val_dataset = EndoVis18VQAGPTClassification(val_seq, folder_head, folder_tail)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    model = PitVQANet()
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss().to(device)

    print('Start training.')
    for epoch in range(start_epoch, args.epochs+1):

        if epochs_since_improvement > 0 and epochs_since_improvement % 5 == 0:
            adjust_learning_rate(optimizer, 0.8)

        # train
        train_acc = train(train_dataloader=train_dataloader, model=model, criterion=criterion,
                          optimizer=optimizer, epoch=epoch, device=device)
        # validation
        test_acc, test_c_acc, test_precision, test_recall, test_f_score \
            = validate(val_loader=val_dataloader, model=model,
                       criterion=criterion, epoch=epoch, device=device)

        if test_acc >= best_results[0]:
            print('Best Epoch:', epoch)
            epochs_since_improvement = 0
            best_results[0] = test_acc
            best_epoch[0] = epoch
            # save_clf_checkpoint(args.checkpoint_dir, epoch, epochs_since_improvement, model, optimizer, best_results[0], final_args=None)
    print('End training.')

Total files: 1560 | Total question: 9014
Total files: 447 | Total question: 2769


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Start training.


100%|██████████| 564/564 [07:38<00:00,  1.23it/s]


Train: epoch: 1 loss: 749.9205704331398 | Acc: 0.5531395606833814 | Precision: 0.304617882252632 | Recall: 0.21474973611447395 | F1 Score: 0.5430291712784366


100%|██████████| 174/174 [02:17<00:00,  1.27it/s]


Test: epoch: 1 test loss: 231.21328592300415 | test acc: 0.4373420007222824 | test precision: 0.6847414725862906 | test recall: 0.23905632341684419 | test F1: 0.21758988175548738
Best Epoch: 1


100%|██████████| 564/564 [07:55<00:00,  1.19it/s]


Train: epoch: 2 loss: 587.6171333491802 | Acc: 0.5849789216773907 | Precision: 0.6361702777936046 | Recall: 0.2843718203257731 | F1 Score: 0.28424603559969835


100%|██████████| 174/174 [02:18<00:00,  1.26it/s]


Test: epoch: 2 test loss: 209.4750133752823 | test acc: 0.4608161791260383 | test precision: 0.7371143098625357 | test recall: 0.2689257695157929 | test F1: 0.21417729393035229
Best Epoch: 2


  7%|▋         | 42/564 [00:42<08:44,  1.01s/it]


KeyboardInterrupt: 