# Skintelligence
## Custom model workflow

### Importing packages and data

In [1]:
import os
import copy
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import models
from transformers import BertModel, BertTokenizer

In [2]:
# The project root is the parent of the notebook's working directory; every input CSV lives in Data/Final.
proj_dir = os.path.join(os.getcwd(), os.pardir)
total_dataset_file, train_img_file, val_img_file, test_img_file = (
    os.path.join(proj_dir, 'Data', 'Final', csv_name)
    for csv_name in ('Final Complete Dataset.csv', 'img_train.csv', 'img_val.csv', 'img_test.csv')
)

In [3]:
# Full annotated dataset, followed by the image lists of the train / validation / test splits
df = pd.read_csv(total_dataset_file)
df_img_train, df_img_val, df_img_test = (
    pd.read_csv(split_path) for split_path in (train_img_file, val_img_file, test_img_file)
)

In [4]:
# Columns not needed for the VQA workflow
unused_cols = [
    'id', 'ori_file_path', 'caption_zh_polish_en', 'caption_zh', 'caption_zh_polish', 'remark',
    'source', 'skin_tone', 'malignant', 'fitzpatrick_scale', 'fitzpatrick_centaur',
    'nine_partition_label', 'three_partition_label', 'url', 'Do not consider this image',
]
# Reassign instead of mutating in place. errors='ignore' keeps the cell safe to re-run:
# a second run no longer raises KeyError because the columns are already gone.
df = df.drop(columns=unused_cols, errors='ignore')

In [5]:
# Columns remaining in the dataset
df.columns

Index(['skincap_file_path', 'disease', 'question', 'answer', 'Vesicle',
       'Papule', 'Macule', 'Plaque', 'Abscess', 'Pustule', 'Bulla', 'Patch',
       'Nodule', 'Ulcer', 'Crust', 'Erosion', 'Excoriation', 'Atrophy',
       'Exudate', 'Purpura/Petechiae', 'Fissure', 'Induration', 'Xerosis',
       'Telangiectasia', 'Scale', 'Scar', 'Friable', 'Sclerosis',
       'Pedunculated', 'Exophytic/Fungating', 'Warty/Papillomatous',
       'Dome-shaped', 'Flat topped', 'Brown(Hyperpigmentation)', 'Translucent',
       'White(Hypopigmentation)', 'Purple', 'Yellow', 'Black', 'Erythema',
       'Comedo', 'Lichenification', 'Blue', 'Umbilicated', 'Poikiloderma',
       'Salmon', 'Wheal', 'Acuminate', 'Burrow', 'Gray', 'Pigmented', 'Cyst'],
      dtype='object')

In [6]:
# Column holding the image file name
img_name_col = 'skincap_file_path'

# Minimum number of positive (1) values a medical-annotation column needs in the training split to be
# kept; rarer labels are too under-represented in the dataset to learn from.
threshold = 450

# Every training-split column other than the image name and the disease is a medical annotation:
# count the positives per annotation and keep the names of those above the threshold.
column_sum = df_img_train[
    [col_name for col_name in df_img_train.columns if col_name not in (img_name_col, 'disease')]
].sum()
med_annot_names = column_sum[column_sum.gt(threshold)].index

In [7]:
# Select the VQA rows whose image belongs to each split (train / validation / test)
df_vqa_train, df_vqa_val, df_vqa_test = (
    df[df[img_name_col].isin(split_frame[img_name_col])]
    for split_frame in (df_img_train, df_img_val, df_img_test)
)

In [8]:
# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize questions and answers for the entire dataset. padding=True pads every sequence to the
# longest one in the dataset; truncation=True cuts sequences longer than the model's maximum length.
questions = df['question'].tolist()
answers = df['answer'].tolist()

question_encoding = tokenizer(questions, padding=True, truncation=True, return_tensors='pt')
answer_encoding = tokenizer(answers, padding=True, truncation=True, return_tensors='pt')

# Store the encodings row-aligned with df as plain Python lists
df['question_input_ids'] = question_encoding['input_ids'].tolist()
df['question_attention_mask'] = question_encoding['attention_mask'].tolist()

df['answer_input_ids'] = answer_encoding['input_ids'].tolist()
df['answer_attention_mask'] = answer_encoding['attention_mask'].tolist()

# The split frames were sliced from df before the token columns existed, and boolean indexing
# returns a copy. Re-slice them here so the train / val / test frames carry the encodings too.
df_vqa_train = df[df[img_name_col].isin(df_img_train[img_name_col])]
df_vqa_val = df[df[img_name_col].isin(df_img_val[img_name_col])]
df_vqa_test = df[df[img_name_col].isin(df_img_test[img_name_col])]

### Multilabel classification model

In [9]:
# Initialize the MLC model (same head architecture as used for training) and load the pretrained weights
mlc_model = models.mobilenet_v3_large(weights=None)
num_ftrs_in = mlc_model.classifier[0].in_features  # width of the pooled backbone features entering the head
mlc_model.classifier = nn.Sequential(
    nn.Linear(num_ftrs_in, 1024),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(256, len(med_annot_names)),
    nn.Sigmoid()
)
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine (the model here
# lives on the CPU anyway). weights_only=True restricts unpickling to tensors/containers, which is
# all a state_dict needs, and avoids executing arbitrary pickled code.
state_dict = torch.load('mlc_mobilenet_checkpoint.pth', map_location='cpu', weights_only=True)
mlc_model.load_state_dict(state_dict)
mlc_model.eval()  # Set to evaluation mode since we are only extracting features

# Freeze the MLC parameters so they are not updated during VQA training
mlc_model.requires_grad_(False)

### Image Feature Extractor / CNN model

In [10]:
# Reuse the trained (and frozen) MLC backbone as the image encoder. The deep copy leaves mlc_model's
# head intact; swapping the copy's head for an identity makes forward() return the pooled backbone
# features (width num_ftrs_in) instead of annotation probabilities.
cnn_model = copy.deepcopy(mlc_model)
cnn_model.classifier = nn.Identity()  # head removed -> raw feature vector

In [11]:
# Inspect the feature-extractor architecture (its classifier is now nn.Identity)
cnn_model

MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bi

### Question-Answer Encoding model

In [12]:
# Initialize the pretrained BERT model
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Evaluation mode turns off BERT's dropout layers so question embeddings are deterministic.
# NOTE(review): eval() does not freeze the weights. They still require gradients, so an optimizer built
# over vqa_model.parameters() will fine-tune BERT. Call bert_model.requires_grad_(False) if BERT is
# meant to stay fixed, as the "without training" intent suggests.
bert_model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [13]:
# VQA Model without Attention
class VQAModel(nn.Module):
    """Predict answer-token logits from an image, a tokenized question and medical annotations.

    Image features come from ``cnn_model``, the question is encoded with ``bert_model``
    (attention-masked mean over the last hidden states), and the annotation vector is either
    supplied by the caller or predicted by ``mlc_model``. The concatenation is fed as a single
    time step to an LSTM whose output is projected onto ``vocab_size`` logits.

    Args:
        mlc_model: Multilabel classifier with an ``nn.Sequential`` ``classifier`` head whose first
            and last ``nn.Linear`` layers define the feature and annotation widths.
        cnn_model: Copy of ``mlc_model`` with ``classifier = nn.Identity()``.
        bert_model: ``BertModel`` (or compatible) exposing ``config.hidden_size`` and returning an
            object with ``last_hidden_state``.
        hidden_size: LSTM hidden size.
        vocab_size: Number of output logits (one per vocabulary token).
        image_feature_size: Width of ``cnn_model``'s output. Inferred from ``mlc_model.classifier``
            when omitted.
        num_annotations: Number of annotation labels. Inferred from ``mlc_model.classifier`` when
            omitted.

    Raises:
        ValueError: if a size must be inferred but ``mlc_model.classifier`` contains no ``nn.Linear``.
    """

    def __init__(self, mlc_model, cnn_model, bert_model, hidden_size, vocab_size,
                 image_feature_size=None, num_annotations=None):
        super().__init__()
        self.mlc_model = mlc_model
        self.cnn_model = cnn_model
        self.bert_model = bert_model

        if image_feature_size is None or num_annotations is None:
            # Assumes cnn_model is mlc_model with its classifier replaced by nn.Identity. It then emits
            # exactly what the head's first Linear consumes. For torchvision's mobilenet_v3_large that
            # is 960, not the 1280 that used to be hard-coded. The head's last Linear has one output
            # per medical annotation.
            head_linears = [m for m in mlc_model.classifier.modules() if isinstance(m, nn.Linear)]
            if not head_linears:
                raise ValueError(
                    'Cannot infer feature sizes: mlc_model.classifier has no nn.Linear layer; '
                    'pass image_feature_size and num_annotations explicitly.'
                )
            if image_feature_size is None:
                image_feature_size = head_linears[0].in_features
            if num_annotations is None:
                num_annotations = head_linears[-1].out_features

        question_size = bert_model.config.hidden_size  # 768 for bert-base-uncased

        self.lstm = nn.LSTM(input_size=image_feature_size + question_size + num_annotations,
                            hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        """Average (batch, seq, dim) hidden states over the sequence, ignoring padded positions."""
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # avoid division by zero for all-padding rows
        return summed / counts

    def train(self, mode=True):
        """Set training mode, but keep fully frozen sub-models in eval mode.

        ``nn.Module.train`` recurses into every child. That would re-enable dropout and BatchNorm
        running-statistics updates inside the frozen MLC / CNN feature extractors, even though their
        weights receive no gradients. Sub-models without any trainable parameter are therefore put
        back into eval mode.
        """
        super().train(mode)
        for sub_model in (self.mlc_model, self.cnn_model, self.bert_model):
            if not any(p.requires_grad for p in sub_model.parameters()):
                sub_model.eval()
        return self

    def forward(self, images, questions, medical_annotations=None):
        """Return answer-token logits of shape (batch_size, vocab_size).

        Args:
            images: Image batch accepted by ``cnn_model`` and ``mlc_model``.
            questions: Mapping with ``'input_ids'`` and ``'attention_mask'`` tensors, presumably of
                shape (batch_size, seq_len), e.g. a tokenizer encoding.
            medical_annotations: (batch_size, num_annotations) annotation vector. When None, the
                predictions of ``mlc_model(images)`` are used.
        """
        # Extract features from images using the CNN backbone: (batch_size, image_feature_size)
        image_features = self.cnn_model(images)

        # Embed the question with BERT and mean-pool over the real (non-padding) tokens only
        input_ids = questions['input_ids']
        attention_mask = questions['attention_mask']
        hidden_states = self.bert_model(input_ids, attention_mask=attention_mask).last_hidden_state
        question_embedding = self._masked_mean(hidden_states, attention_mask)

        if medical_annotations is None:
            medical_annotations = self.mlc_model(images)
        # Annotations built from a DataFrame are often float64; align them with the network dtype
        medical_annotations = medical_annotations.to(image_features.dtype)

        # Concatenate image features, question embedding, and medical annotations
        combined_features = torch.cat((image_features, question_embedding, medical_annotations), dim=1)

        # Single LSTM step: (batch_size, seq_len=1, hidden_size)
        lstm_out, _ = self.lstm(combined_features.unsqueeze(1))

        # One vocabulary-sized logit vector per sample (a single token, not a whole sequence)
        return self.fc(lstm_out[:, -1, :])

In [14]:
# LSTM width of the fusion model (tunable)
hidden_size = 512
# One output logit per token of the BERT tokenizer vocabulary
vocab_size = len(tokenizer.vocab)

vqa_model = VQAModel(
    mlc_model,
    cnn_model,
    bert_model,
    hidden_size=hidden_size,
    vocab_size=vocab_size,
)

class AttentionLayer(nn.Module):
    """Fuse image, question and annotation vectors with a softmax attention over the three modalities.

    Each modality is projected into a shared ``attention_size`` space and scored by a shared linear
    layer. The three scores are normalised with a softmax, so the modality weights sum to 1 for
    every sample. The output is the weighted sum of the image features and the question/annotation
    vectors mapped to the image-feature width, so it has shape ``(batch_size, cnn_output_size)``
    whatever the input widths are.

    When ``bert_hidden_size`` or ``annotation_size`` already equals ``cnn_output_size``, the
    corresponding mapping is ``nn.Identity`` (no extra parameters).
    """

    def __init__(self, cnn_output_size, bert_hidden_size, annotation_size, attention_size):
        super().__init__()

        # Projections into the shared attention space, used only to score each modality
        self.image_fc = nn.Linear(cnn_output_size, attention_size)
        self.question_fc = nn.Linear(bert_hidden_size, attention_size)
        self.annotation_fc = nn.Linear(annotation_size, attention_size)

        # Shared scorer: one scalar relevance score per modality
        self.attention_weight = nn.Linear(attention_size, 1)

        # Map question / annotation vectors to the image-feature width so they can be summed with the
        # image features. Summing the raw inputs only works when all three widths happen to match.
        self.question_value = (nn.Identity() if bert_hidden_size == cnn_output_size
                               else nn.Linear(bert_hidden_size, cnn_output_size))
        self.annotation_value = (nn.Identity() if annotation_size == cnn_output_size
                                 else nn.Linear(annotation_size, cnn_output_size))

    def _modality_weights(self, image_features, question_embeddings, annotations):
        """Return softmax-normalised weights of shape (batch_size, 3) for (image, question, annotations)."""
        scores = torch.cat(
            [
                self.attention_weight(F.relu(self.image_fc(image_features))),
                self.attention_weight(F.relu(self.question_fc(question_embeddings))),
                self.attention_weight(F.relu(self.annotation_fc(annotations))),
            ],
            dim=1,
        )  # (batch_size, 3)
        return F.softmax(scores, dim=1)

    def forward(self, image_features, question_embeddings, annotations):
        """
        image_features: CNN output (batch_size, cnn_output_size)
        question_embeddings: BERT question embeddings (batch_size, bert_hidden_size)
        annotations: Multilabel annotation predictions (batch_size, annotation_size)

        Returns the attention-weighted fusion, shape (batch_size, cnn_output_size).
        """
        weights = self._modality_weights(image_features, question_embeddings, annotations)

        # Slicing with ranges keeps a (batch_size, 1) shape that broadcasts over the feature dimension
        return (weights[:, 0:1] * image_features
                + weights[:, 1:2] * self.question_value(question_embeddings)
                + weights[:, 2:3] * self.annotation_value(annotations))