###### -----------------START--------------------------------------------

In [1]:
import json

In [2]:
import os

In [3]:
from tqdm import tqdm

In [4]:
import matplotlib.pyplot as plt

In [5]:
train_file_path = '/home/aritra/cric/train_questions.json'
val_file_path = '/home/aritra/cric/val_questions.json'
test_file_path = '/home/aritra/cric/test_v1_questions.json'

In [6]:
# Training Set

with open(train_file_path, "r") as file:
     train_json = json.load(file)

In [7]:
# Validation Set

with open(val_file_path, "r") as file:
     val_json = json.load(file)

In [8]:
# Test Set

with open(test_file_path, "r") as file:
     test_json = json.load(file)

In [9]:
len(train_json)

365235

In [10]:
len(val_json)

43112

In [11]:
len(test_json)

86003

In [12]:
train_json[1099]['question']

'which brown animal walking in the field could be used for transporting people'

In [13]:
val_json[1099]['question']

'is there an object that is a type of public transports'

In [14]:
test_json[1099]['question']

'can the ceramic bird spread wings'

### ------------------------------Extracting Data of Training Set-------------------------------------------------------------------------------



In [15]:
questionList = []
answerList = []
imgList = []

In [16]:
train_json[2]['image_id']

1005

#### iter 1: from 0 , 149000 -> error1.txt -> 159
#### iter 2: from 150000 , 240000 -> error2.txt -> 34
#### iter 3: from 240000 , 365235 ->error3.txt -> 121

In [17]:
# verifying
indexToExclude = []

with open('error1.txt', 'r') as file:
    for line in file:
        number = int(line.strip())
        indexToExclude.append(number)
        
with open('error2.txt', 'r') as file:
    for line in file:
        number = int(line.strip())
        indexToExclude.append(number)
        
with open('error3.txt', 'r') as file:
    for line in file:
        number = int(line.strip())
        indexToExclude.append(number)

In [18]:
len(indexToExclude)

314

In [19]:
for i in tqdm(range(len(train_json))):
    
    if i in indexToExclude:
        continue
        
    pointer = train_json[i]
    
    questionList.append(pointer['question'])
    answerList.append(pointer['answer'])
    imgList.append(pointer['image_id'])

100%|███████████████████████████████████| 365235/365235 [00:01<00:00, 319729.85it/s]


In [20]:
len(questionList), len(answerList), len(imgList)

(364921, 364921, 364921)

In [21]:
len(list(set(answerList)))

1442

### ---------------------------------------Map Creation--------------------------------------------------------

In [22]:
def findUnique(targetList):
    
    uniqueList = []
    
    for word in targetList:
        if word not in uniqueList:
            uniqueList.append(word)
    
    return uniqueList

In [23]:
len(findUnique(answerList))

1442

In [24]:
# creating word to number mapping

mapping = {}
counter = 0

uniqueAnsList = findUnique(answerList)

for word in uniqueAnsList:
    
    if word not in mapping:
        
        mapping[word] = counter
        counter += 1

In [25]:
uniqueAnsList[0:5]

['no', 'small', 'picture', 'table', 'bookshelf']

In [26]:
numOfClasses = max(mapping.values())
numOfClasses

1441

In [27]:
len(mapping)

1442

In [28]:
# creating number to word mapping

reverse_mapping = dict([(value, key) for key, value in mapping.items()])

### --------------------------------------Processing of Training Set--------------------------------------------------------------------

In [29]:
labels = []

for i in range(len(answerList)):
    labels.append( mapping[ answerList[i] ] )

In [30]:
len(labels)

364921

In [31]:
scores = []

for i in tqdm(range(len(answerList))):
    
    s = [0] * (numOfClasses+1)
    s[ mapping[ answerList[i]] ] = 1
    
    scores.append(s)

100%|████████████████████████████████████| 364921/364921 [00:03<00:00, 91805.66it/s]


In [32]:
len(scores)

364921

In [33]:
imgPathList = []
filepath = '/home/aritra/cric/images/img/'

for i in tqdm(range(len(imgList))):
    
    imgName = str(imgList[i]) + '.jpg'
    concatedPath = os.path.join(filepath,imgName)
    
    imgPathList.append(concatedPath)

100%|███████████████████████████████████| 364921/364921 [00:00<00:00, 767628.01it/s]


In [34]:
from datasets import load_dataset
from datasets import Dataset
import datasets
from PIL import Image
import torch

In [35]:
imgPathList[0:5]

['/home/aritra/cric/images/img/1000.jpg',
 '/home/aritra/cric/images/img/1005.jpg',
 '/home/aritra/cric/images/img/1005.jpg',
 '/home/aritra/cric/images/img/1005.jpg',
 '/home/aritra/cric/images/img/1008.jpg']

In [36]:
len(imgPathList)

364921

In [37]:
listToDictionary = {'questions':questionList, 'labels': labels, 'scores': scores, 'images':imgPathList}
modified_train_set = Dataset.from_dict(listToDictionary)

In [38]:
# mapping each filepath to images in the directory

modified_train_set = modified_train_set.cast_column("images", datasets.Image())

In [39]:
modified_train_set

Dataset({
    features: ['questions', 'labels', 'scores', 'images'],
    num_rows: 364921
})

## -----------------------------------Extracting Validation Set---------------------------------------------

In [40]:
questionList_val = []
answerList_val = []
imgList_val = []

In [41]:
# collecting the index containing errorneous images

indexToExcludeVal = []
with open('error_validation.txt', 'r') as file:
    for line in file:
        number = int(line.strip())
        indexToExcludeVal.append(number)

with open('error_validation2.txt', 'r') as file:
    for line in file:
        number = int(line.strip())  # Convert the read line to an integer
        indexToExcludeVal.append(number)


In [42]:
# excluding the index containing errorneous images

for i in tqdm(range(len(val_json))):
    
    if (i in indexToExcludeVal):
        continue
        
    pointer = val_json[i]
    
    questionList_val.append(pointer['question'])
    answerList_val.append(pointer['answer'])
    imgList_val.append(pointer['image_id'])

100%|██████████████████████████████████████| 43112/43112 [00:02<00:00, 16205.99it/s]


43112 -> 43068 -> 33175

In [43]:
len(questionList_val), len(answerList_val), len(imgList_val)

(33175, 33175, 33175)

In [44]:
uniqueAnswerListVal = list(set(answerList_val))
len(uniqueAnswerListVal)

266

In [45]:
# check if all the uniques answers are present in the mapping

y,n = 0,0
store = []
for i in range(len(answerList_val)):
    
    word = answerList_val[i]
    
    if word in mapping:
        y += 1
    else:
        n+=1
        store.append(i)

In [46]:
y

33175

### --------------------------------------------------------Processing Validation Set-------------------------------------------------------

In [47]:
labels_val = []

for i in range(len(answerList_val)):
    labels_val.append( mapping[ answerList_val[i] ] )

In [48]:
len(labels_val)

33175

In [49]:
scores_val = []

for i in tqdm(range(len(answerList_val))):
    
    s = [0] * (numOfClasses+1)
    s[ mapping[ answerList_val[i]] ] = 1
    
    scores_val.append(s)

100%|██████████████████████████████████████| 33175/33175 [00:00<00:00, 92510.18it/s]


In [50]:
len(scores_val)

33175

In [51]:
imgPathList_val = []
filepath = '/home/aritra/cric/images/img/'

for i in tqdm(range(len(imgList_val))):
    
    imgName = str(imgList_val[i]) + '.jpg'
    concatedPath = os.path.join(filepath,imgName)
    
    imgPathList_val.append(concatedPath)

100%|█████████████████████████████████████| 33175/33175 [00:00<00:00, 762001.44it/s]


In [52]:
imgPathList_val[0:5]

['/home/aritra/cric/images/img/1003.jpg',
 '/home/aritra/cric/images/img/1003.jpg',
 '/home/aritra/cric/images/img/1018.jpg',
 '/home/aritra/cric/images/img/1018.jpg',
 '/home/aritra/cric/images/img/1027.jpg']

In [53]:
# creating HF dataset to map images fast of Val_set

listToDictionary = {'questions':questionList_val, 'labels':labels_val, 'scores':scores_val, 'images':imgPathList_val}
modified_val_set = Dataset.from_dict(listToDictionary)

In [54]:
# mapping each filepath of Val Set to images in the directory

modified_val_set = modified_val_set.cast_column("images", datasets.Image())

### -------------------------------------------Extracting Color Questions Set-------------------------------------------------


In [55]:
# color questions of the train set is stored in this file.
# objective is to train a fresh model on only color based questions

indices = []
with open('./text_files/color_questions_indices_train_set.txt', 'r') as file:
    for number in file:
        number = int(number.strip())
        indices.append(number)


In [56]:
indices[0:5]

[0, 1, 16, 29, 32]

In [57]:
questionList_color = []
answerList_color = []
imgList_color = []

In [58]:
# from the predefined list of color indices of the train set here the questions,answer,img is being copied 

for i in indices:
                
    questionList_color.append( questionList[i] )
    answerList_color.append( answerList[i] )
    imgList_color.append( imgList[i] )

In [59]:
questionList_color[10:15]

['which green thing near the green house could be opened or closed',
 'what type of object is on back of the tan object that I can use for sitting on',
 'can the black electronic device that is wearing the hand control small electrical appliance',
 'which color is the object that is on back of the seat and is a type of electronic device',
 'is there a yellow vehicle that can travel on road']

In [60]:
len(questionList_color), len(answerList_color), len(imgList_color)

(149163, 149163, 149163)

In [61]:
uniqueAnswerListColor = list(set(answerList_color))
len(uniqueAnswerListColor)

1065

In [62]:
# check if all the uniques answers are present in the mapping

y,n = 0,0
store = []
for i in range(len(answerList_color)):
    
    word = answerList_color[i]
    
    if word in mapping:
        y += 1
    else:
        n+=1
        store.append(i)

In [63]:
y

149163

In [64]:
# Processing of color set

labels_color = []

for i in range(len(answerList_color)):
    labels_color.append( mapping[ answerList_color[i] ] )

In [65]:
len(labels_color)

149163

In [66]:
scores_color = []

for i in tqdm(range(len(answerList_color))):
    
    s = [0] * (numOfClasses+1)
    s[ mapping[ answerList_color[i]] ] = 1
    
    scores_color.append(s)

100%|████████████████████████████████████| 149163/149163 [00:01<00:00, 90223.22it/s]


In [67]:
len(scores_color)

149163

In [68]:
imgPathList_color = []
filepath = '/home/aritra/cric/images/img/'

for i in tqdm(range(len(imgList_color))):
    
    imgName = str(imgList_color[i]) + '.jpg'
    concatedPath = os.path.join(filepath,imgName)
    
    imgPathList_color.append(concatedPath)

100%|███████████████████████████████████| 149163/149163 [00:00<00:00, 769226.47it/s]


In [69]:
len(imgPathList_color)

149163

In [70]:
imgPathList_val[0:5]

['/home/aritra/cric/images/img/1003.jpg',
 '/home/aritra/cric/images/img/1003.jpg',
 '/home/aritra/cric/images/img/1018.jpg',
 '/home/aritra/cric/images/img/1018.jpg',
 '/home/aritra/cric/images/img/1027.jpg']

In [71]:
# creating HF dataset to map images fast of Val_set

listToDictionary = {'questions':questionList_color, 'labels':labels_color, 'scores':scores_color, 'images':imgPathList_color}
train_color_set = Dataset.from_dict(listToDictionary)

In [72]:
# mapping each filepath of Val Set to images in the directory

train_color_set = train_color_set.cast_column("images", datasets.Image())

### -------------------------------------------Extracting Test Set-------------------------------------------------


In [73]:
questionList_test = []
answerList_test = []
imgList_test = []

In [74]:
indexToExcludeTest = []

with open('error_testSet1.txt', 'r') as file:
    for line in file:
        number = int(line.strip())
        indexToExcludeTest.append(number)
        
with open('errorTestSet2.txt', 'r') as file:
    for line in file:
        number = int(line.strip())
        indexToExcludeTest.append(number)

In [75]:
len(indexToExcludeTest)

14150

In [76]:
for i in tqdm(range(len(test_json))):
    
    if i in indexToExcludeTest:
        continue
        
    pointer = test_json[i]
    
    questionList_test.append(pointer['question'])
    answerList_test.append(pointer['answer'])
    imgList_test.append(pointer['image_id'])

100%|██████████████████████████████████████| 86003/86003 [00:07<00:00, 11041.05it/s]


86003 -> 71863

### -------------------------------------- Processing Test Set ----------------------------------------------------------------------------

In [77]:
# check if all the uniques answers are present in the mapping

y,n = 0,0
store = []
for i in range(len(answerList_test)):
    
    word = answerList_test[i]
    
    if word in mapping:
        y += 1
    else:
        n+=1
        store.append(i)

In [78]:
y

71863

In [79]:
labels_test = []

for i in range(len(answerList_test)):
    labels_test.append( mapping[ answerList_test[i] ] )

In [80]:
len(labels_test)

71863

In [81]:
scores_test = []

for i in tqdm(range(len(answerList_test))):
    
    s = [0] * (numOfClasses+1)
    s[ mapping[ answerList_test[i]] ] = 1
    
    scores_test.append(s)

100%|██████████████████████████████████████| 71863/71863 [00:00<00:00, 90940.69it/s]


In [82]:
len(scores_test)

71863

In [83]:
imgPathList_test = []
filepath = '/home/aritra/cric/images/img/'

for i in tqdm(range(len(imgList_test))):
    
    imgName = str(imgList_test[i]) + '.jpg'
    concatedPath = os.path.join(filepath,imgName)
    
    imgPathList_test.append(concatedPath)

100%|█████████████████████████████████████| 71863/71863 [00:00<00:00, 774324.92it/s]


In [84]:
len(imgPathList_test)

71863

In [85]:
imgPathList_test[0:5]

['/home/aritra/cric/images/img/1004.jpg',
 '/home/aritra/cric/images/img/1004.jpg',
 '/home/aritra/cric/images/img/1004.jpg',
 '/home/aritra/cric/images/img/1004.jpg',
 '/home/aritra/cric/images/img/1004.jpg']

In [86]:
# creating HF dataset to map images fast of test_set

listToDictionary = {'questions':questionList_test, 'labels':labels_test, 'scores':scores_test, 'images':imgPathList_test}
modified_test_set = Dataset.from_dict(listToDictionary)

In [87]:
# mapping each filepath of test Set to images in the directory

modified_test_set = modified_test_set.cast_column("images", datasets.Image())

### -------------------------------End of Processing----------------------------------------------------------------------------

In [88]:
from transformers import ViltProcessor, ViltForQuestionAnswering

In [89]:
from transformers import ViltConfig
config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

In [90]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [91]:
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

In [92]:
model = ViltForQuestionAnswering.from_pretrained("model_chkpts/vilt-mlm-classification-model/vilt_mlm_mod_e4_cric_trained/", id2label = reverse_mapping, label2id = mapping).to(device)

In [93]:
from torch.utils.data import DataLoader
from datasets import Dataset
import numpy as np

In [94]:
class cric_dataset(Dataset):
    
    def __init__(self, dataset, processor):
        self.processor = processor
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self,idx):
        
        #print(idx)
        item = self.dataset[idx]

        #print(item)
        
        encodings = self.processor(images = item["images"], text = item["questions"], padding="max_length", truncation=True, return_tensors = "pt")
        encodings = {k:v.squeeze() for k,v in encodings.items()}
                                
        encodings['labels'] = torch.tensor(item['scores'], dtype = torch.float32)
        
        return encodings

In [95]:
train_dataset_object = cric_dataset(modified_train_set, processor)

In [96]:
val_dataset_object = cric_dataset(modified_val_set, processor)

In [97]:
test_dataset_object = cric_dataset(modified_test_set, processor)

In [98]:
color_dataset_object = cric_dataset(train_color_set, processor)

In [99]:
def collate_fn(batch):
  
    input_ids = [item['input_ids'] for item in batch]
    pixel_values = [item['pixel_values'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    token_type_ids = [item['token_type_ids'] for item in batch]
    labels = [item['labels'] for item in batch]
        
    # create padded pixel values and corresponding pixel mask
    
    encoding = processor.image_processor.pad(pixel_values, return_tensors="pt")

    # create new batch
    
    batch = {}
    
    batch['input_ids'] = torch.stack(input_ids)
    batch['attention_mask'] = torch.stack(attention_mask)
    batch['token_type_ids'] = torch.stack(token_type_ids)
    batch['pixel_values'] = encoding['pixel_values']
    batch['pixel_mask'] = encoding['pixel_mask']
    batch['labels'] = torch.stack(labels, dim = 0 )

    return batch


In [100]:
train_dataloader = DataLoader(color_dataset_object, collate_fn = collate_fn, shuffle = True, batch_size = 32)

In [101]:
batch = next(iter(train_dataloader))

In [102]:
for k,v in batch.items():
    print(k, v.shape)
    print()
    
#print(batch.keys())

input_ids torch.Size([32, 40])

attention_mask torch.Size([32, 40])

token_type_ids torch.Size([32, 40])

pixel_values torch.Size([32, 3, 608, 608])

pixel_mask torch.Size([32, 608, 608])

labels torch.Size([32, 1442])



In [103]:
tot_number_of_steps = len(train_dataloader)
tot_number_of_steps

4662

In [104]:
# for fp16 precision

from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler

scaler = GradScaler()

In [105]:
# No Visualisation for this Test Notebook

#from torch.utils.tensorboard import SummaryWriter
#writer = SummaryWriter()

## Model Training Loop

In [108]:
optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-5)

model.train()

for epoch in tqdm(range(1)):  

    print(f"Epoch: {epoch}")

    for idx, batch in enumerate(train_dataloader):

        batch = {k:v.to(device) for k,v in batch.items()}

        optimizer.zero_grad()
        
        with torch.autocast(device_type='cuda', dtype=torch.float16):

            outputs = model(**batch)
            loss = outputs.loss

        print(idx,"-> Loss:", loss.item())
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)

        scaler.update()
        
        # Plots in tensorboard
    
        if (idx != 0 ) and (idx % 1000 == 0):
            
            model.eval()
            
            acc_score_test = calculateAccuracyTest()
            acc_score_val, validationLoss = calculateAccuracyVal()
            
            print(f'\nValidation Accuracy: {acc_score_val}, Test Accuracy: {acc_score_test} \n')
            
            #writer.add_scalar('Training Loss', loss.item(), epoch * tot_number_of_steps + idx)
            #writer.add_scalar('Validation Loss', validationLoss, epoch * tot_number_of_steps + idx)

            #writer.add_scalar('Accuracy Score On Val Set', acc_score_val, epoch * tot_number_of_steps + idx)
            #writer.add_scalar('Accuracy Score On Test Set', acc_score_test, epoch * tot_number_of_steps + idx)
            
            model.train()
            
    # Save model checkpoint
    
    save_path = os.path.join('./model_chkpts/test/', 'vilt_mlm_color_e' + '5' + '_cric_trained')
    model.save_pretrained(save_path)
    print("Model Saved At: ", epoch)
    
#writer.close()



  0%|                                                         | 0/1 [00:00<?, ?it/s]

Epoch: 0


  return F.conv2d(input, weight, bias, self.stride,


0 -> Loss: 0.5371490716934204
1 -> Loss: 0.23276397585868835
2 -> Loss: 0.35655367374420166
3 -> Loss: 0.49922704696655273
4 -> Loss: 0.6733691692352295
5 -> Loss: 0.5433361530303955
6 -> Loss: 1.0520079135894775
7 -> Loss: 0.4814547300338745
8 -> Loss: 0.4572630524635315
9 -> Loss: 0.3161604404449463
10 -> Loss: 0.5157613754272461
11 -> Loss: 0.5291028618812561
12 -> Loss: 0.4964756965637207
13 -> Loss: 0.5185451507568359
14 -> Loss: 0.04080929234623909
15 -> Loss: 0.5630484819412231
16 -> Loss: 0.32479608058929443
17 -> Loss: 0.6960514783859253
18 -> Loss: 0.7995201349258423
19 -> Loss: 0.44206932187080383
20 -> Loss: 0.3656335175037384
21 -> Loss: 0.4552389681339264
22 -> Loss: 0.2756907045841217
23 -> Loss: 0.42417505383491516
24 -> Loss: 0.8510345816612244
25 -> Loss: 0.26363715529441833
26 -> Loss: 0.4432034194469452
27 -> Loss: 0.705394446849823
28 -> Loss: 0.5126279592514038
29 -> Loss: 0.7259231805801392
30 -> Loss: 0.4939831495285034
31 -> Loss: 0.7674992680549622
32 -> Loss:

259 -> Loss: 0.27108097076416016
260 -> Loss: 0.527826189994812
261 -> Loss: 0.5838673710823059
262 -> Loss: 0.7018933296203613
263 -> Loss: 0.14541533589363098
264 -> Loss: 0.45691418647766113
265 -> Loss: 0.4368586540222168
266 -> Loss: 0.13244548439979553
267 -> Loss: 0.2880104184150696
268 -> Loss: 0.4282197952270508
269 -> Loss: 0.21098196506500244
270 -> Loss: 0.2622842490673065
271 -> Loss: 0.7317684292793274
272 -> Loss: 0.40075239539146423
273 -> Loss: 0.8611503839492798
274 -> Loss: 0.29605022072792053
275 -> Loss: 0.21642595529556274
276 -> Loss: 0.6913865804672241
277 -> Loss: 0.4979076683521271
278 -> Loss: 0.3675096333026886
279 -> Loss: 0.9330558180809021
280 -> Loss: 0.3321142792701721
281 -> Loss: 0.44862326979637146
282 -> Loss: 0.4151404798030853
283 -> Loss: 0.36282452940940857
284 -> Loss: 0.4901552200317383
285 -> Loss: 0.06830372661352158
286 -> Loss: 0.7116612195968628
287 -> Loss: 0.8327693343162537
288 -> Loss: 0.7594789266586304
289 -> Loss: 0.574701786041259

514 -> Loss: 0.3638819456100464
515 -> Loss: 0.5962555408477783
516 -> Loss: 0.38886767625808716
517 -> Loss: 0.6665236949920654
518 -> Loss: 0.2738398313522339
519 -> Loss: 0.22141176462173462
520 -> Loss: 0.22383053600788116
521 -> Loss: 0.8388539552688599
522 -> Loss: 0.26361143589019775
523 -> Loss: 0.5790871381759644
524 -> Loss: 0.367141991853714
525 -> Loss: 0.3153378963470459
526 -> Loss: 0.8531593084335327
527 -> Loss: 0.48080137372016907
528 -> Loss: 0.4085765480995178
529 -> Loss: 0.7370085120201111
530 -> Loss: 0.40140289068222046
531 -> Loss: 0.45800086855888367
532 -> Loss: 0.530412495136261
533 -> Loss: 0.6282362937927246
534 -> Loss: 0.1508878469467163
535 -> Loss: 0.772908627986908
536 -> Loss: 0.33029472827911377
537 -> Loss: 0.6174072623252869
538 -> Loss: 0.3727363049983978
539 -> Loss: 0.31696566939353943
540 -> Loss: 0.7487247586250305
541 -> Loss: 0.5088855028152466
542 -> Loss: 0.30606651306152344
543 -> Loss: 0.4410032033920288
544 -> Loss: 0.6288946866989136
5

769 -> Loss: 1.060248613357544
770 -> Loss: 0.7374250888824463
771 -> Loss: 0.870110273361206
772 -> Loss: 0.24526375532150269
773 -> Loss: 0.49328744411468506
774 -> Loss: 0.18299981951713562
775 -> Loss: 0.5860162973403931
776 -> Loss: 0.6000025272369385
777 -> Loss: 0.5822156071662903
778 -> Loss: 0.8296589255332947
779 -> Loss: 0.4015807509422302
780 -> Loss: 0.4388345777988434
781 -> Loss: 0.2461174726486206
782 -> Loss: 0.47471311688423157
783 -> Loss: 0.34213709831237793
784 -> Loss: 0.2503000497817993
785 -> Loss: 0.5483832359313965
786 -> Loss: 0.4859844744205475
787 -> Loss: 0.5199170112609863
788 -> Loss: 0.22734658420085907
789 -> Loss: 0.3786564767360687
790 -> Loss: 0.45610886812210083
791 -> Loss: 0.4401465654373169
792 -> Loss: 0.6520127058029175
793 -> Loss: 0.6121006608009338
794 -> Loss: 0.7271037101745605
795 -> Loss: 0.798469603061676
796 -> Loss: 0.6800267100334167
797 -> Loss: 0.6739587783813477
798 -> Loss: 0.19953536987304688
799 -> Loss: 0.2520343065261841
800

1021 -> Loss: 0.41861724853515625
1022 -> Loss: 0.2911081910133362
1023 -> Loss: 0.4622322618961334
1024 -> Loss: 0.49772074818611145
1025 -> Loss: 0.14924265444278717
1026 -> Loss: 0.20530125498771667
1027 -> Loss: 0.60121750831604
1028 -> Loss: 0.34848424792289734
1029 -> Loss: 0.4601617455482483
1030 -> Loss: 0.4411923289299011
1031 -> Loss: 0.4913756847381592
1032 -> Loss: 0.29567456245422363
1033 -> Loss: 0.6030915379524231
1034 -> Loss: 0.6569069623947144
1035 -> Loss: 0.45350611209869385
1036 -> Loss: 1.2566273212432861
1037 -> Loss: 0.6887289881706238
1038 -> Loss: 0.5198811888694763
1039 -> Loss: 0.2637079060077667
1040 -> Loss: 0.5428491234779358
1041 -> Loss: 0.4652807116508484
1042 -> Loss: 0.21622249484062195
1043 -> Loss: 0.4863138794898987
1044 -> Loss: 0.11350523680448532
1045 -> Loss: 0.7810757160186768
1046 -> Loss: 0.5388825535774231
1047 -> Loss: 0.3297642767429352
1048 -> Loss: 0.26083287596702576
1049 -> Loss: 0.29000335931777954
1050 -> Loss: 0.49779096245765686


1268 -> Loss: 0.4895821213722229
1269 -> Loss: 0.3977683484554291
1270 -> Loss: 0.5042760372161865
1271 -> Loss: 1.4033793210983276
1272 -> Loss: 0.3048066794872284
1273 -> Loss: 0.5521018505096436
1274 -> Loss: 0.5059213042259216
1275 -> Loss: 0.4293927550315857
1276 -> Loss: 0.28390559554100037
1277 -> Loss: 0.545642077922821
1278 -> Loss: 0.4833724796772003
1279 -> Loss: 0.38583970069885254
1280 -> Loss: 0.4086553752422333
1281 -> Loss: 0.16834409534931183
1282 -> Loss: 0.34124815464019775
1283 -> Loss: 0.41192299127578735
1284 -> Loss: 0.7259250283241272
1285 -> Loss: 0.2692745327949524
1286 -> Loss: 1.000659465789795
1287 -> Loss: 0.4118349850177765
1288 -> Loss: 0.38503795862197876
1289 -> Loss: 0.6695672273635864
1290 -> Loss: 0.2621490955352783
1291 -> Loss: 0.8266687393188477
1292 -> Loss: 0.8023719191551208
1293 -> Loss: 0.26815265417099
1294 -> Loss: 0.5144680738449097
1295 -> Loss: 0.7854167222976685
1296 -> Loss: 0.4181743562221527
1297 -> Loss: 0.5840034484863281
1298 -> 

1515 -> Loss: 0.4128221869468689
1516 -> Loss: 0.3100527226924896
1517 -> Loss: 0.3835766911506653
1518 -> Loss: 0.3094867467880249
1519 -> Loss: 0.7600144147872925
1520 -> Loss: 0.6223241090774536
1521 -> Loss: 0.28757399320602417
1522 -> Loss: 0.27379336953163147
1523 -> Loss: 0.71135014295578
1524 -> Loss: 0.2654394805431366
1525 -> Loss: 0.1395794302225113
1526 -> Loss: 0.7617769837379456
1527 -> Loss: 0.46322715282440186
1528 -> Loss: 0.5118148326873779
1529 -> Loss: 0.26710593700408936
1530 -> Loss: 0.45770829916000366
1531 -> Loss: 0.46733373403549194
1532 -> Loss: 0.16879218816757202
1533 -> Loss: 0.09633130580186844
1534 -> Loss: 0.5497941970825195
1535 -> Loss: 0.4812707006931305
1536 -> Loss: 0.7690078616142273
1537 -> Loss: 0.18623501062393188
1538 -> Loss: 0.219236820936203
1539 -> Loss: 0.22053927183151245
1540 -> Loss: 0.12661543488502502
1541 -> Loss: 0.4444517195224762
1542 -> Loss: 0.5830420255661011
1543 -> Loss: 0.42833131551742554
1544 -> Loss: 0.14887996017932892


1762 -> Loss: 0.7224410176277161
1763 -> Loss: 0.4359722137451172
1764 -> Loss: 0.5263280868530273
1765 -> Loss: 0.1656692773103714
1766 -> Loss: 0.5483464002609253
1767 -> Loss: 0.4483884871006012
1768 -> Loss: 0.15624430775642395
1769 -> Loss: 0.32391566038131714
1770 -> Loss: 0.2612667679786682
1771 -> Loss: 0.3677468001842499
1772 -> Loss: 0.5689023733139038
1773 -> Loss: 0.1978551596403122
1774 -> Loss: 0.40585091710090637
1775 -> Loss: 0.6440900564193726
1776 -> Loss: 0.6246020793914795
1777 -> Loss: 0.4571179449558258
1778 -> Loss: 0.6552560329437256
1779 -> Loss: 1.0301792621612549
1780 -> Loss: 0.28834250569343567
1781 -> Loss: 0.8422187566757202
1782 -> Loss: 0.4712551534175873
1783 -> Loss: 0.4825180768966675
1784 -> Loss: 0.805219292640686
1785 -> Loss: 0.35316893458366394
1786 -> Loss: 0.4764402210712433
1787 -> Loss: 0.4076770544052124
1788 -> Loss: 0.6315345764160156
1789 -> Loss: 0.4098452031612396
1790 -> Loss: 0.823925793170929
1791 -> Loss: 0.3200361728668213
1792 ->

2008 -> Loss: 0.4427875578403473
2009 -> Loss: 0.3073482811450958
2010 -> Loss: 0.18806962668895721
2011 -> Loss: 0.6840811371803284
2012 -> Loss: 0.4952585995197296
2013 -> Loss: 0.8758946657180786
2014 -> Loss: 0.17238108813762665
2015 -> Loss: 0.8759961724281311
2016 -> Loss: 0.6940659284591675
2017 -> Loss: 0.44242367148399353
2018 -> Loss: 0.34295615553855896
2019 -> Loss: 0.4370356798171997
2020 -> Loss: 0.10729019343852997
2021 -> Loss: 0.9113518595695496
2022 -> Loss: 0.4941244125366211
2023 -> Loss: 0.34972083568573
2024 -> Loss: 0.1880316287279129
2025 -> Loss: 0.40662112832069397
2026 -> Loss: 0.40659061074256897
2027 -> Loss: 0.815680742263794
2028 -> Loss: 0.6695623993873596
2029 -> Loss: 0.558904767036438
2030 -> Loss: 0.4200460910797119
2031 -> Loss: 0.38119643926620483
2032 -> Loss: 0.3801000416278839
2033 -> Loss: 0.60248863697052
2034 -> Loss: 0.3282269835472107
2035 -> Loss: 0.5027815103530884
2036 -> Loss: 0.31355011463165283
2037 -> Loss: 0.20201252400875092
2038 -

2256 -> Loss: 0.17162398993968964
2257 -> Loss: 0.4997127950191498
2258 -> Loss: 0.2614096403121948
2259 -> Loss: 0.6517490744590759
2260 -> Loss: 0.351227343082428
2261 -> Loss: 0.23863670229911804
2262 -> Loss: 0.7055375576019287
2263 -> Loss: 0.5678477883338928
2264 -> Loss: 0.2729538679122925
2265 -> Loss: 0.16993127763271332
2266 -> Loss: 0.8860474824905396
2267 -> Loss: 0.7084066271781921
2268 -> Loss: 0.3576844036579132
2269 -> Loss: 0.33120450377464294
2270 -> Loss: 0.6124591827392578
2271 -> Loss: 0.7007219195365906
2272 -> Loss: 0.7044221758842468
2273 -> Loss: 0.4011329114437103
2274 -> Loss: 0.5105149745941162
2275 -> Loss: 0.5226568579673767
2276 -> Loss: 0.30591827630996704
2277 -> Loss: 0.8432247042655945
2278 -> Loss: 0.6145723462104797
2279 -> Loss: 0.7241942882537842
2280 -> Loss: 0.543178915977478
2281 -> Loss: 0.39905861020088196
2282 -> Loss: 0.37991437315940857
2283 -> Loss: 0.17450794577598572
2284 -> Loss: 0.8730781674385071
2285 -> Loss: 0.3537532091140747
2286

2503 -> Loss: 0.7622277140617371
2504 -> Loss: 0.22580160200595856
2505 -> Loss: 0.6766939759254456
2506 -> Loss: 0.5180466771125793
2507 -> Loss: 0.6503024101257324
2508 -> Loss: 0.4537903070449829
2509 -> Loss: 0.44295811653137207
2510 -> Loss: 1.4277722835540771
2511 -> Loss: 1.0786546468734741
2512 -> Loss: 0.7096835374832153
2513 -> Loss: 0.3406650125980377
2514 -> Loss: 0.45332494378089905
2515 -> Loss: 0.5793893337249756
2516 -> Loss: 0.45234400033950806
2517 -> Loss: 0.2788439393043518
2518 -> Loss: 0.7448671460151672
2519 -> Loss: 0.06367488205432892
2520 -> Loss: 0.7035291194915771
2521 -> Loss: 0.7751990556716919
2522 -> Loss: 0.3338624835014343
2523 -> Loss: 0.37497276067733765
2524 -> Loss: 0.44651931524276733
2525 -> Loss: 0.4220639765262604
2526 -> Loss: 0.329913854598999
2527 -> Loss: 0.3794189393520355
2528 -> Loss: 0.161190927028656
2529 -> Loss: 0.5637768507003784
2530 -> Loss: 0.6462589502334595
2531 -> Loss: 0.6204675436019897
2532 -> Loss: 0.903236985206604
2533 -

2750 -> Loss: 0.29512736201286316
2751 -> Loss: 0.7316688299179077
2752 -> Loss: 0.15580902993679047
2753 -> Loss: 0.8404857516288757
2754 -> Loss: 0.08417962491512299
2755 -> Loss: 0.400633841753006
2756 -> Loss: 0.7255812287330627
2757 -> Loss: 0.2553991973400116
2758 -> Loss: 0.13882137835025787
2759 -> Loss: 0.8030213117599487
2760 -> Loss: 0.2520498037338257
2761 -> Loss: 0.32619917392730713
2762 -> Loss: 0.6761478185653687
2763 -> Loss: 0.6059584021568298
2764 -> Loss: 0.29340195655822754
2765 -> Loss: 0.5281155109405518
2766 -> Loss: 0.3348442316055298
2767 -> Loss: 0.43723106384277344
2768 -> Loss: 0.6890143156051636
2769 -> Loss: 0.5693326592445374
2770 -> Loss: 0.512026309967041
2771 -> Loss: 0.6937447190284729
2772 -> Loss: 0.31797298789024353
2773 -> Loss: 0.33162617683410645
2774 -> Loss: 0.8420431017875671
2775 -> Loss: 0.5494641661643982
2776 -> Loss: 0.6599085330963135
2777 -> Loss: 0.3318813443183899
2778 -> Loss: 0.2893619239330292
2779 -> Loss: 0.43129104375839233
27

2997 -> Loss: 0.2953836917877197
2998 -> Loss: 0.4084415137767792
2999 -> Loss: 0.1472795009613037
3000 -> Loss: 0.6400266885757446

Validation Accuracy: 66.0, Test Accuracy: 69.0 

3001 -> Loss: 0.5581330060958862
3002 -> Loss: 0.23533594608306885
3003 -> Loss: 0.6867280006408691
3004 -> Loss: 0.5135515928268433
3005 -> Loss: 0.9730560779571533
3006 -> Loss: 0.46886277198791504
3007 -> Loss: 1.2490272521972656
3008 -> Loss: 0.5014830231666565
3009 -> Loss: 0.25987911224365234
3010 -> Loss: 0.5308150053024292
3011 -> Loss: 0.4832765460014343
3012 -> Loss: 0.6552945971488953
3013 -> Loss: 0.3839279115200043
3014 -> Loss: 0.2182316780090332
3015 -> Loss: 0.5197250843048096
3016 -> Loss: 0.36965394020080566
3017 -> Loss: 0.6047237515449524
3018 -> Loss: 0.5339654088020325
3019 -> Loss: 0.29821810126304626
3020 -> Loss: 0.31071415543556213
3021 -> Loss: 1.0663871765136719
3022 -> Loss: 0.48596853017807007
3023 -> Loss: 0.213052898645401
3024 -> Loss: 0.34535348415374756
3025 -> Loss: 0.535

3242 -> Loss: 0.7791478037834167
3243 -> Loss: 0.20682628452777863
3244 -> Loss: 0.56416255235672
3245 -> Loss: 0.31709474325180054
3246 -> Loss: 0.5480600595474243
3247 -> Loss: 0.6978244781494141
3248 -> Loss: 0.6663452386856079
3249 -> Loss: 0.4410147964954376
3250 -> Loss: 0.6533072590827942
3251 -> Loss: 0.746029257774353
3252 -> Loss: 0.6160250902175903
3253 -> Loss: 0.6651992201805115
3254 -> Loss: 0.21638745069503784
3255 -> Loss: 0.49191635847091675
3256 -> Loss: 0.2883581817150116
3257 -> Loss: 0.36123549938201904
3258 -> Loss: 0.8632410764694214
3259 -> Loss: 0.46077194809913635
3260 -> Loss: 0.9752052426338196
3261 -> Loss: 0.4846882224082947
3262 -> Loss: 0.3124257028102875
3263 -> Loss: 0.3093239367008209
3264 -> Loss: 0.5019283294677734
3265 -> Loss: 0.4434535503387451
3266 -> Loss: 0.5530391335487366
3267 -> Loss: 0.6261914968490601
3268 -> Loss: 0.5550099611282349
3269 -> Loss: 0.31767648458480835
3270 -> Loss: 0.29302558302879333
3271 -> Loss: 0.23900045454502106
3272

3489 -> Loss: 0.49097028374671936
3490 -> Loss: 0.9671489596366882
3491 -> Loss: 0.6158559322357178
3492 -> Loss: 0.5108720064163208
3493 -> Loss: 0.8883898854255676
3494 -> Loss: 0.8236611485481262
3495 -> Loss: 0.33331841230392456
3496 -> Loss: 0.39126530289649963
3497 -> Loss: 0.41070684790611267
3498 -> Loss: 0.4509291350841522
3499 -> Loss: 0.5128753185272217
3500 -> Loss: 0.08263424038887024
3501 -> Loss: 0.3150462210178375
3502 -> Loss: 0.40181106328964233
3503 -> Loss: 0.5398712158203125
3504 -> Loss: 0.4579780697822571
3505 -> Loss: 0.40026459097862244
3506 -> Loss: 0.6796720027923584
3507 -> Loss: 0.4406299889087677
3508 -> Loss: 0.28926408290863037
3509 -> Loss: 0.4215927720069885
3510 -> Loss: 0.7705248594284058
3511 -> Loss: 0.23431746661663055
3512 -> Loss: 0.8801308870315552
3513 -> Loss: 0.5623769760131836
3514 -> Loss: 0.9035126566886902
3515 -> Loss: 0.6095193028450012
3516 -> Loss: 0.7568979859352112
3517 -> Loss: 0.34291794896125793
3518 -> Loss: 0.6805120706558228


3736 -> Loss: 0.3432021737098694
3737 -> Loss: 0.42579054832458496
3738 -> Loss: 0.35568559169769287
3739 -> Loss: 0.292233407497406
3740 -> Loss: 0.5338556170463562
3741 -> Loss: 0.5247376561164856
3742 -> Loss: 0.9093618392944336
3743 -> Loss: 0.38599830865859985
3744 -> Loss: 0.6417112946510315
3745 -> Loss: 0.6375554800033569
3746 -> Loss: 0.40475842356681824
3747 -> Loss: 0.30372026562690735
3748 -> Loss: 0.2742557227611542
3749 -> Loss: 1.20908784866333
3750 -> Loss: 0.270112007856369
3751 -> Loss: 0.44411394000053406
3752 -> Loss: 0.8584640622138977
3753 -> Loss: 0.7182532548904419
3754 -> Loss: 0.10936320573091507
3755 -> Loss: 0.5495408773422241
3756 -> Loss: 0.38925567269325256
3757 -> Loss: 0.16524986922740936
3758 -> Loss: 0.4498109817504883
3759 -> Loss: 0.5899953842163086
3760 -> Loss: 0.34367311000823975
3761 -> Loss: 0.37960389256477356
3762 -> Loss: 0.4367661476135254
3763 -> Loss: 0.7000197172164917
3764 -> Loss: 0.3369613587856293
3765 -> Loss: 0.3144649267196655
376

3983 -> Loss: 0.28351345658302307
3984 -> Loss: 0.45766931772232056
3985 -> Loss: 0.6193770170211792
3986 -> Loss: 0.5360029339790344
3987 -> Loss: 0.2959366738796234
3988 -> Loss: 0.35712164640426636
3989 -> Loss: 0.6055095791816711
3990 -> Loss: 0.16504666209220886
3991 -> Loss: 0.6373705863952637
3992 -> Loss: 0.7779238820075989
3993 -> Loss: 0.8513798117637634
3994 -> Loss: 0.777021586894989
3995 -> Loss: 0.7404495477676392
3996 -> Loss: 0.28082290291786194
3997 -> Loss: 0.7418543100357056
3998 -> Loss: 0.8973482251167297
3999 -> Loss: 0.6786487102508545
4000 -> Loss: 0.4906977117061615

Validation Accuracy: 68.5, Test Accuracy: 73.0 

4001 -> Loss: 1.2212183475494385
4002 -> Loss: 0.381523072719574
4003 -> Loss: 0.3047163188457489
4004 -> Loss: 0.6080929040908813
4005 -> Loss: 0.5316324234008789
4006 -> Loss: 0.5097926259040833
4007 -> Loss: 0.24223873019218445
4008 -> Loss: 0.7684056162834167
4009 -> Loss: 0.4153101146221161
4010 -> Loss: 0.6498619914054871
4011 -> Loss: 0.439891

4229 -> Loss: 0.9669020175933838
4230 -> Loss: 0.4972416162490845
4231 -> Loss: 0.6770495772361755
4232 -> Loss: 0.617667555809021
4233 -> Loss: 0.7423809170722961
4234 -> Loss: 0.3555912673473358
4235 -> Loss: 0.5299884676933289
4236 -> Loss: 0.4626416862010956
4237 -> Loss: 0.6321544647216797
4238 -> Loss: 0.7998590469360352
4239 -> Loss: 0.6336766481399536
4240 -> Loss: 0.7092546224594116
4241 -> Loss: 0.17118629813194275
4242 -> Loss: 0.19824503362178802
4243 -> Loss: 0.566721498966217
4244 -> Loss: 0.45388442277908325
4245 -> Loss: 0.3000195622444153
4246 -> Loss: 0.39488354325294495
4247 -> Loss: 0.33465251326560974
4248 -> Loss: 0.6091470122337341
4249 -> Loss: 0.36915451288223267
4250 -> Loss: 0.8268702626228333
4251 -> Loss: 0.3700551986694336
4252 -> Loss: 0.8854940533638
4253 -> Loss: 0.45941463112831116
4254 -> Loss: 0.41017627716064453
4255 -> Loss: 0.5586594939231873
4256 -> Loss: 0.623632550239563
4257 -> Loss: 0.39537787437438965
4258 -> Loss: 0.13175277411937714
4259 -

4476 -> Loss: 0.5878736972808838
4477 -> Loss: 0.4675315320491791
4478 -> Loss: 0.5279601216316223
4479 -> Loss: 1.0035465955734253
4480 -> Loss: 0.6351914405822754
4481 -> Loss: 0.4929311275482178
4482 -> Loss: 0.11803578585386276
4483 -> Loss: 0.6235116720199585
4484 -> Loss: 0.4695885181427002
4485 -> Loss: 0.41513553261756897
4486 -> Loss: 0.36554983258247375
4487 -> Loss: 0.26845353841781616
4488 -> Loss: 0.641502320766449
4489 -> Loss: 0.5997942090034485
4490 -> Loss: 0.45004355907440186
4491 -> Loss: 0.904541015625
4492 -> Loss: 0.5891513228416443
4493 -> Loss: 0.35151320695877075
4494 -> Loss: 0.24684825539588928
4495 -> Loss: 0.5392006635665894
4496 -> Loss: 0.8868497610092163
4497 -> Loss: 0.6278976798057556
4498 -> Loss: 0.41750213503837585
4499 -> Loss: 0.7766048312187195
4500 -> Loss: 0.5278530120849609
4501 -> Loss: 0.8324658870697021
4502 -> Loss: 0.1490786075592041
4503 -> Loss: 0.41637587547302246
4504 -> Loss: 0.04900688678026199
4505 -> Loss: 0.654534101486206
4506 -

100%|█████████████████████████████████████████████| 1/1 [1:35:05<00:00, 5705.46s/it]

Model Saved At:  0





In [None]:
model.eval()

In [None]:
index = 1

In [None]:
example = val_dataset_object[index]
print(example.keys())

In [None]:
delLab = example.pop('labels')

In [None]:
processor.decode(example['input_ids'])

In [None]:
# add batch dimension + move to GPU
example = {k: v.unsqueeze(0).to(device) for k,v in example.items()}

# forward pass
outputs = model(**example)

In [None]:
logits = outputs.logits
logits.shape

In [None]:
logits

In [None]:
print(torch.argmax(logits).item())
reverse_mapping[logits.argmax(-1).item()]

In [None]:
answerList[index]

In [None]:
predicted_classes = torch.sigmoid(logits)
probs, classes = torch.topk(predicted_classes, 5)

for prob, class_idx in zip(probs.squeeze().tolist(), classes.squeeze().tolist()):
  print(prob, model.config.id2label[class_idx])

In [None]:
i = Image.open(imgPathList[index])
i.thumbnail((300,300))
i

## Reports & Results

In [106]:
# This function returns the Validation Loss and accuracy on the Validation Set

def calculateAccuracyVal():
    
    matchScore, loopCounter = 0,0
    
    for index in range(0,200):
        
        loopCounter += 1
        
        val_example = val_dataset_object[index]
        val_example = {k: v.unsqueeze(0).to(device) for k,v in val_example.items()}
        val_outputs = model(**val_example)
        
        validationLoss = val_outputs.loss

        val_logits = val_outputs.logits
        val_predicted_classes = torch.sigmoid(val_logits)
        val_ans = reverse_mapping[torch.argmax(val_predicted_classes).item()]
        
        
        # accuracy score
        
        if answerList_val[index] == val_ans:
            matchScore += 1
                
    #print(matchScore, loopCounter)
    accuracyVal = (matchScore/loopCounter)*100
    return ( accuracyVal,validationLoss.item() )

In [None]:
calculateAccuracyVal()

In [107]:
# This function returns accuracy on the Test Set

def calculateAccuracyTest():
    
    matchScore, loopCounter = 0,0
    model.eval()
    for index in range(0, 200):
        
        loopCounter += 1
        
        test_example = test_dataset_object[index]
        test_example = {k: v.unsqueeze(0).to(device) for k,v in test_example.items()}
        test_outputs = model(**test_example)

        test_logits = test_outputs.logits
        test_predicted_classes = torch.sigmoid(test_logits)
        test_ans = reverse_mapping[torch.argmax(test_predicted_classes).item()]
        
        # print(f'T: {answerList_val[index]} <-> P: {test_ans}' )

        # accuracy score
        
        if answerList_test[index] == test_ans:
            matchScore += 1
                
    #print(matchScore, loopCounter)
    return ((matchScore/loopCounter)*100)

In [None]:
calculateAccuracyTest()

In [None]:
# This function returns report on the Test Set

misclassifiedIndex = []
def generateReport():
    
    matchScore, loopCounter = 0,0
    model.eval()
    
    for index in range(0,500):
        
        loopCounter += 1
        print(f'\n{questionList_test[index]} ? Ans: {answerList_test[index]}\n')
        
        example = test_dataset_object[index]
        example = {k: v.unsqueeze(0).to(device) for k,v in example.items()}
        outputs = model(**example)

        logits = outputs.logits
        predicted_classes = torch.sigmoid(logits)
        ans = reverse_mapping[torch.argmax(predicted_classes).item()]
        
        print('Predicted Ans:', ans,'\n')
        
        probs, classes = torch.topk(predicted_classes, 4)

        for prob, class_idx in zip(probs.squeeze().tolist(), classes.squeeze().tolist()):
            print(prob, model.config.id2label[class_idx])
    
        # accuracy score
        
        if answerList_test[index] == ans:
            matchScore += 1
            print('Correct Prediction at index:', index)
        
        else:
            misclassifiedIndex.append(index)
            print('Wrong Prediction at index:', index)
    
    return ((matchScore/loopCounter)*100)

In [None]:
generateReport()

In [None]:
misclassifiedIndex

In [None]:
i = Image.open(imgPathList_test[53054])
i.thumbnail((500,500))
i

## Find Questions With Color And Its Accuracy

In [None]:
# extracting the list of colors from the previously stored text files

colors = []
with open('./text_files/colors.txt', 'r') as file:
    for color in file:
        color = color.strip()
        colors.append(color)

In [None]:
colors[0:5]

In [None]:
# adding leading and trailing space in the colors

colors_spaces = [' '+ color + ' ' for color in colors] 

In [None]:
colors_spaces[0:5]

In [None]:
def isContainColor(targetString, colorList):
    
    for color in colorList:
        if color in targetString:
            return True
    
    return False    

In [None]:
# This function finds the accuracy on color question and identifies the color question for which the result is misclassified

misclassifiedIndex = []
colorFrequency = {color: 0 for color in colors}

def findColorQuestions():
    
    global colors
    matchScore, questionCount = 0,0
    model.eval()
    
    print('***** Question About Colors ************')
    
    for index in tqdm(range(1000,2000)):
        
        currQuestion = questionList_test[index]        
        
        if ('color' in currQuestion) or (isContainColor(currQuestion, colors_spaces)):
            
            questionCount += 1
            
            #print(f'\n{questionList_test[index]} ? Ans: {answerList_test[index]}\n')
            
            example = test_dataset_object[index]
            example = {k: v.unsqueeze(0).to(device) for k,v in example.items()}
            outputs = model(**example)

            logits = outputs.logits
            predicted_classes = torch.sigmoid(logits)
            ans = reverse_mapping[torch.argmax(predicted_classes).item()]

            # accuracy score

            if answerList_test[index] == ans:
                matchScore += 1

            else:
                
                if answerList_test[index] in colors:
                    colorFrequency[ answerList_test[index] ] = colorFrequency[ answerList_test[index] ] + 1
                    
                misclassifiedIndex.append(index)
                                
        else:
            
            continue
    
    
    print(f'\nTotal {questionCount} questions found')
    return ((matchScore/questionCount)*100)

In [None]:
findColorQuestions()

In [None]:
len(misclassifiedIndex)

In [None]:
colorFrequency

In [None]:
import matplotlib.pyplot as plt

In [None]:
categories = colorFrequency.keys()
values = colorFrequency.values()

In [None]:
plt.figure(figsize=(14, 7))

plt.xlabel(f"Model Misclassified Total {sum(values)} Questions involving Colors")
plt.ylabel("Values")
plt.title("Number of times Ground truth Colors which has been misclassified by the model")

plt.bar(colorFrequency.keys(), colorFrequency.values(), color='lightblue', edgecolor='black', width=0.4)
plt.grid(True)

# Show the plot
plt.show()

## Store Color Question Indices

In [None]:
# extracting the list of colors from the previously stored text files

colors_train = []
with open('./text_files/colors_train.txt', 'r') as file:
    for color in file:
        color = color.replace("\n","")
        colors_train.append(color)

In [None]:
# This function collects the indices of the color questions from the train set

colorQuestionIndices = []

def storeColorQuestionIndex():
    
    questionCount = 0 
    print('********* Storing Color Questions Indices ************')
    
    for index in tqdm(range(0,len(answerList))):
        
        currAnswer = answerList[index]  
        currQuestion = questionList[index]
                
        if ('color' in currQuestion) or (isContainColor(currQuestion,colors_train)):
            #print(index,currQuestion)

            questionCount += 1
            colorQuestionIndices.append(index)
        
                
    print(f'\nTotal {questionCount} color questions found')

In [None]:
storeColorQuestionIndex()

In [None]:
len(colorQuestionIndices)

In [None]:
colorQuestionIndices[1000:1005]

In [None]:
questionList[2560]

## Words in color Questions Histogram

In [None]:
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')

In [None]:
def remove_stopwords(wordList):
    stop_words = set(stopwords.words('english'))
    filtered_words = [word for word in wordList if word.lower() not in stop_words]
    return filtered_words

In [None]:
# this function gathers all the words from the color questions and their frequency from all the color questions to make histogram

frquencyMap = {}

def collectWords():
    
    for index in tqdm(misclassifiedIndex):
        
        currQuestion = questionList_test[index]
        words = remove_stopwords(currQuestion.split())
        
        for word in words:
            
            if word in frquencyMap:
                
                frquencyMap[word] = frquencyMap[word] + 1
            
            else:
            
                frquencyMap[word] = 1


In [None]:
collectWords()

In [None]:
len(frquencyMap)

In [None]:
# convert to list of tuples

frquencyList = [(key,val) for key,val in frquencyMap.items()]

In [None]:
frquencyList[0:5]

In [None]:
frquencyList.sort(key = lambda x:x[1], reverse = True)

In [None]:
frquencyList[0:5]

In [None]:
import seaborn as sns

In [None]:
# top-30 words

frquencyList = frquencyList[0:30]
labels = [ val[0] for val in frquencyList]
frequncies = [ val[1] for val in frquencyList]

In [None]:
sns.set(style="whitegrid")
plt.figure(figsize=(10, 6))
ax = sns.barplot(x=frequncies, y=labels, palette="viridis")

# Customize the plot
ax.set(xlabel="Frequency", ylabel="Words", title="Word Frequency")
plt.tight_layout()

plt.show()

# Experimenting With Accuracy By Removing Most Frequent Words

In [None]:
len(questionList_test)

In [None]:
for index in tqdm(range(0,500)):

    currQuestion = questionList_test[index]
    if 'object' or 'used' in currQuestion:
        currQuestion = ' '.join([word for word in currQuestion.split() if word not in ('object','used')])
    
    questionList_test[index] = currQuestion

# creating HF dataset to map images fast of test_set

listToDictionary = {'questions':questionList_test, 'labels':labels_test, 'scores':scores_test, 'images':imgPathList_test}
word_removed_test_set = Dataset.from_dict(listToDictionary)

In [None]:
word_removed_test_set = word_removed_test_set.cast_column("images", datasets.Image())

In [None]:
word_removed_test_set[10]['questions'], modified_test_set[10]['questions']

In [None]:
word_removed_test_set_object = cric_dataset(word_removed_test_set, processor)

In [None]:
# This function returns report on the Test Set

misclassifiedIndex = []
def removeWordsAndGenReport():
    
    matchScore, loopCounter = 0,0
    model.eval()
    
    for index in tqdm(range(0,1000)):
        
        loopCounter += 1                            

        example = word_removed_test_set_object[index]
        example = {k: v.unsqueeze(0).to(device) for k,v in example.items()}
        outputs = model(**example)

        logits = outputs.logits
        predicted_classes = torch.sigmoid(logits)
        ans = reverse_mapping[torch.argmax(predicted_classes).item()]
        
        # print('Predicted Ans:', ans,'\n')
        
        probs, classes = torch.topk(predicted_classes, 4)

        for prob, class_idx in zip(probs.squeeze().tolist(), classes.squeeze().tolist()):
            print(end='')
    
        # accuracy score
        
        if answerList_test[index] == ans:
            matchScore += 1
            #print('Correct Prediction at index:', index)
        
        else:
            misclassifiedIndex.append(index)
            #print('Wrong Prediction at index:', index)
    
    return ((matchScore/loopCounter)*100)

In [None]:
removeWordsAndGenReport()

In [None]:
len(questionList)

In [None]:
pwd

In [None]:
cat text_files/colors_train.txt

In [None]:
# find the color questions of the training set

# This function identifies the color question for which the result is misclassified

colorMap = {}
questionCount = 0

def findColorQuestionsTraining():
        
    print('***** Question About Colors ************')
    
    for index in tqdm(range(len(questionList))):
        
        currQuestion = questionList[index]        
        #print(currQuestion)
        
        if ('color' in currQuestion):
            
            global questionCount
            questionCount += 1
            currAnswer = answerList[index]
            
            if currAnswer in colorMap:              
                colorMap[currAnswer] = colorMap[currAnswer] + 1
            else:
                colorMap[currAnswer] = 1

            
        else:
            
            continue
            
    return (colorMap)

In [None]:
findColorQuestionsTraining()