In [2]:
!pip install wget tqdm

Collecting wget
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9656 sha256=ee732d30d21ccc3bdb7e28c06983a03e3ceba6ee89efe3af378f0b2d1bd2ea46
  Stored in directory: /root/.cache/pip/wheels/8b/f1/7f/5c94f0a7a505ca1c81cd1d9208ae2064675d97582078e6c769
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2


In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, ViTImageProcessor, ViTModel, get_cosine_schedule_with_warmup
from PIL import Image
import pandas as pd
import numpy as np
import os
import json
from tqdm import tqdm

# Constants
MAX_LENGTH = 128
BATCH_SIZE = 16
LEARNING_RATE = 5e-5
NUM_EPOCHS = 45
IMAGE_SIZE = 224
HIDDEN_SIZE = 1024  # Increased hidden size
DROPOUT_RATE = 0.5
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 1000


In [4]:
def prepare_kaggle_daquar(input_dir, output_dir):
    """Prepare DAQUAR dataset from Kaggle format with compound answer handling"""
    print("Processing Kaggle DAQUAR dataset...")
    
    # Read CSV files
    train_data = pd.read_csv(os.path.join(input_dir, 'data_train.csv'))
    eval_data = pd.read_csv(os.path.join(input_dir, 'data_eval.csv'))
    
    # Read image lists
    with open(os.path.join(input_dir, 'train_images_list.txt'), 'r') as f:
        train_images = [line.strip() for line in f.readlines()]
    with open(os.path.join(input_dir, 'test_images_list.txt'), 'r') as f:
        test_images = [line.strip() for line in f.readlines()]
    
    # Process answers to handle compound answers
    def get_all_answers(data):
        answer_set = set()
        for answer in data['answer']:
            # Split compound answers and strip whitespace
            parts = [part.strip() for part in str(answer).split(',')]
            answer_set.update(parts)
        return sorted(list(answer_set))
    
    # Create answer vocabulary from both train and eval sets
    answer_vocab = get_all_answers(pd.concat([train_data, eval_data]))
    answer_to_idx = {ans: idx for idx, ans in enumerate(answer_vocab)}
    
    def create_annotations(data):
        annotations = []
        for _, row in data.iterrows():
            # Split compound answers into individual answers
            answers = [ans.strip() for ans in str(row['answer']).split(',')]
            # Use the first answer as the primary answer
            primary_answer = answers[0]
            
            ann = {
                'image': f"{row['image_id']}.png",
                'question': row['question'],
                'answer': primary_answer,  # Use only the primary answer
                'all_answers': answers  # Keep all answers for potential future use
            }
            annotations.append(ann)
        return annotations
    
    # Create annotation files
    train_annotations = create_annotations(train_data)
    test_annotations = create_annotations(eval_data)
    
    # Create processed directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Save processed annotations
    with open(os.path.join(output_dir, 'train_annotations.json'), 'w') as f:
        json.dump(train_annotations, f)
    with open(os.path.join(output_dir, 'test_annotations.json'), 'w') as f:
        json.dump(test_annotations, f)
    
    # Save vocabulary
    with open(os.path.join(output_dir, 'answer_vocab.json'), 'w') as f:
        json.dump({
            'answer_vocab': answer_vocab,
            'answer_to_idx': answer_to_idx
        }, f)
    
    print(f"Dataset prepared successfully!")
    print(f"Total training examples: {len(train_annotations)}")
    print(f"Total test examples: {len(test_annotations)}")
    print(f"Total unique answers: {len(answer_vocab)}")
    
    # Print first few examples
    print("\nFirst few training examples:")
    for i in range(3):
        print(f"Example {i+1}:")
        print(f"Image: {train_annotations[i]['image']}")
        print(f"Question: {train_annotations[i]['question']}")
        print(f"Primary Answer: {train_annotations[i]['answer']}")
        print(f"All Answers: {train_annotations[i]['all_answers']}\n")
    
    return len(answer_vocab)

In [5]:
class KaggleDAQUARDataset(Dataset):
    def __init__(self, input_dir, processed_dir, split='train'):
        self.input_dir = input_dir
        self.processed_dir = processed_dir
        self.split = split
        
        # Load annotations
        with open(os.path.join(processed_dir, f'{split}_annotations.json'), 'r') as f:
            self.annotations = json.load(f)
            
        # Load answer vocabulary
        with open(os.path.join(processed_dir, 'answer_vocab.json'), 'r') as f:
            vocab_data = json.load(f)
            self.answer_vocab = vocab_data['answer_vocab']
            self.answer_to_idx = vocab_data['answer_to_idx']
            
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.image_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
    
    def __len__(self):
        return len(self.annotations)
        
    def __getitem__(self, idx):
        ann = self.annotations[idx]
        
        # Load and preprocess image
        img_path = os.path.join(self.input_dir, 'images', ann['image'])
        image = Image.open(img_path).convert('RGB')
        
        # Process image with ViT image processor
        image_features = self.image_processor(images=image, return_tensors="pt")
        
        # Tokenize question
        question_encoding = self.tokenizer(
            ann['question'],
            padding='max_length',
            max_length=MAX_LENGTH,
            truncation=True,
            return_tensors='pt'
        )
        
        # Convert answer to index
        answer_idx = self.answer_to_idx[ann['answer']]
        
        return {
            'image': image_features.pixel_values[0],
            'input_ids': question_encoding['input_ids'][0],
            'attention_mask': question_encoding['attention_mask'][0],
            'answer': torch.tensor(answer_idx)
        }

In [6]:
    input_dir = '/kaggle/input/processed-daquar-dataset'
    output_dir = '/kaggle/working/processed'
    
    # Prepare dataset
    print("Preparing dataset...")
    num_classes = prepare_kaggle_daquar(input_dir, output_dir)
    
    print(f"\nNumber of answer classes: {num_classes}")
    
    # Test dataset loading
    print("\nTesting dataset loading...")
    try:
        dataset = KaggleDAQUARDataset(
            input_dir=input_dir,
            processed_dir=output_dir,
            split='train'
        )
        print(f"Successfully loaded dataset with {len(dataset)} examples")
        
        # Test loading first item
        first_item = dataset[0]
        print("\nFirst item shapes:")
        print(f"Image: {first_item['image'].shape}")
        print(f"Input IDs: {first_item['input_ids'].shape}")
        print(f"Attention Mask: {first_item['attention_mask'].shape}")
        print(f"Answer: {first_item['answer']}")
        
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")

Preparing dataset...
Processing Kaggle DAQUAR dataset...
Dataset prepared successfully!
Total training examples: 6795
Total test examples: 5673
Total unique answers: 582

First few training examples:
Example 1:
Image: image3.png
Question: what is on the right side of the black telephone and on the left side of the red chair
Primary Answer: desk
All Answers: ['desk']

Example 2:
Image: image3.png
Question: what is in front of the white door on the left side of the desk
Primary Answer: telephone
All Answers: ['telephone']

Example 3:
Image: image3.png
Question: what is on the desk
Primary Answer: book
All Answers: ['book', 'scissor', 'papers', 'tape_dispenser']


Number of answer classes: 582

Testing dataset loading...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Successfully loaded dataset with 6795 examples

First item shapes:
Image: torch.Size([3, 224, 224])
Input IDs: torch.Size([128])
Attention Mask: torch.Size([128])
Answer: 160


In [7]:
class KaggleDAQUARDataset(Dataset):
    def __init__(self, input_dir, processed_dir, split='train'):
        self.input_dir = input_dir
        self.processed_dir = processed_dir
        self.split = split
        
        with open(os.path.join(processed_dir, f'{split}_annotations.json'), 'r') as f:
            self.annotations = json.load(f)
            
        with open(os.path.join(processed_dir, 'answer_vocab.json'), 'r') as f:
            vocab_data = json.load(f)
            self.answer_vocab = vocab_data['answer_vocab']
            self.answer_to_idx = vocab_data['answer_to_idx']
            
        self.tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
        self.image_processor = ViTImageProcessor.from_pretrained('google/vit-large-patch16-224')
    
    def __len__(self):
        return len(self.annotations)
        
    def __getitem__(self, idx):
        ann = self.annotations[idx]
        
        img_path = os.path.join(self.input_dir, 'images', ann['image'])
        image = Image.open(img_path).convert('RGB')
        
        image_features = self.image_processor(images=image, return_tensors="pt")
        
        question_encoding = self.tokenizer(
            ann['question'],
            padding='max_length',
            max_length=MAX_LENGTH,
            truncation=True,
            return_tensors='pt'
        )
        
        answer_idx = self.answer_to_idx[ann['answer']]
        
        return {
            'image': image_features.pixel_values[0],
            'input_ids': question_encoding['input_ids'][0],
            'attention_mask': question_encoding['attention_mask'][0],
            'answer': torch.tensor(answer_idx)
        }


In [8]:
class CrossModalAttention(nn.Module):
    def __init__(self, hidden_size, dropout_rate=DROPOUT_RATE):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=8, batch_first=True, dropout=dropout_rate)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size * 4, hidden_size)
        )
        
    def forward(self, x, y):
        attended_x, _ = self.attention(x, y, y)
        x = self.norm1(x + self.dropout1(attended_x))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        return x


In [9]:
class MultimodalVQAModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fusion = nn.Sequential(
            nn.Linear(HIDDEN_SIZE * 2, HIDDEN_SIZE),
            nn.LayerNorm(HIDDEN_SIZE),
            nn.GELU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
            nn.LayerNorm(HIDDEN_SIZE),
            nn.GELU(),
            nn.Dropout(DROPOUT_RATE)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE // 2),
            nn.LayerNorm(HIDDEN_SIZE // 2),
            nn.GELU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(HIDDEN_SIZE // 2, num_classes)
        )
        
        self.bert = BertModel.from_pretrained('bert-large-uncased')
        self.vit = ViTModel.from_pretrained('google/vit-large-patch16-224')
        
        self.image_to_text_attention = CrossModalAttention(HIDDEN_SIZE)
        self.text_to_image_attention = CrossModalAttention(HIDDEN_SIZE)
        
        self._initialize_weights()
        self.unfreeze_layers(4)
    
    def _initialize_weights(self):
        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        
        self.fusion.apply(init_weights)
        self.classifier.apply(init_weights)
        self.image_to_text_attention.apply(init_weights)
        self.text_to_image_attention.apply(init_weights)
    
    def unfreeze_layers(self, num_layers):
        for param in self.bert.parameters():
            param.requires_grad = False
        for param in self.vit.parameters():
            param.requires_grad = False
            
        for param in self.bert.encoder.layer[-num_layers:].parameters():
            param.requires_grad = True
        for param in self.vit.encoder.layer[-num_layers:].parameters():
            param.requires_grad = True
            
        for module in [self.fusion, self.classifier, self.image_to_text_attention, self.text_to_image_attention]:
            for param in module.parameters():
                param.requires_grad = True
    
    def forward(self, image, input_ids, attention_mask):
        image_features = self.vit(image).last_hidden_state
        text_features = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        
        attended_image = self.image_to_text_attention(image_features, text_features)
        attended_text = self.text_to_image_attention(text_features, image_features)
        
        image_weights = torch.softmax(attended_image.mean(-1), dim=1).unsqueeze(-1)
        text_weights = torch.softmax(attended_text.mean(-1), dim=1).unsqueeze(-1)
        
        image_pooled = (attended_image * image_weights).sum(1)
        text_pooled = (attended_text * text_weights).sum(1)
        
        combined_features = torch.cat((image_pooled, text_pooled), dim=1)
        fused_features = self.fusion(combined_features)
        
        logits = self.classifier(fused_features)
        
        return logits


In [10]:
def create_data_loaders(input_dir, processed_dir, batch_size):
    train_dataset = KaggleDAQUARDataset(input_dir=input_dir, processed_dir=processed_dir, split='train')
    val_dataset = KaggleDAQUARDataset(input_dir=input_dir, processed_dir=processed_dir, split='test')
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    
    return train_loader, val_loader


In [11]:
def train_model(input_dir, processed_dir, num_classes):
    train_loader, val_loader = create_data_loaders(input_dir, processed_dir, BATCH_SIZE)
    
    model = MultimodalVQAModel(num_classes=num_classes)
    criterion = nn.CrossEntropyLoss()
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    
    num_training_steps = len(train_loader) * NUM_EPOCHS
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=num_training_steps)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    best_val_acc = 0
    patience = 5
    patience_counter = 0
    
    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}')
        for batch in progress_bar:
            image = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            answers = batch['answer'].to(device)
            
            optimizer.zero_grad()
            
            outputs = model(image, input_ids, attention_mask)
            loss = criterion(outputs, answers)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += answers.size(0)
            correct += predicted.eq(answers).sum().item()
            
            progress_bar.set_postfix({
                'loss': f'{train_loss/total:.4f}',
                'acc': f'{100.*correct/total:.2f}%',
                'lr': f'{scheduler.get_last_lr()[0]:.2e}'
            })
        
        val_acc = validate_model(model, val_loader, criterion, device)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'\nEarly stopping triggered after {epoch+1} epochs')
                break
    
    model.load_state_dict(torch.load('best_model.pth'))
    return model


In [12]:
def validate_model(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validation'):
            image = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            answers = batch['answer'].to(device)
            
            outputs = model(image, input_ids, attention_mask)
            loss = criterion(outputs, answers)
            
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += answers.size(0)
            val_correct += predicted.eq(answers).sum().item()
    
    val_acc = 100. * val_correct / val_total
    print(f'\nValidation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {val_acc:.2f}%')
    return val_acc


In [13]:
input_dir = '/kaggle/input/processed-daquar-dataset'
processed_dir = '/kaggle/working/processed'
num_classes = prepare_kaggle_daquar(input_dir, processed_dir)
model = train_model(input_dir, processed_dir, num_classes)


Processing Kaggle DAQUAR dataset...
Dataset prepared successfully!
Total training examples: 6795
Total test examples: 5673
Total unique answers: 582

First few training examples:
Example 1:
Image: image3.png
Question: what is on the right side of the black telephone and on the left side of the red chair
Primary Answer: desk
All Answers: ['desk']

Example 2:
Image: image3.png
Question: what is in front of the white door on the left side of the desk
Primary Answer: telephone
All Answers: ['telephone']

Example 3:
Image: image3.png
Question: what is on the desk
Primary Answer: book
All Answers: ['book', 'scissor', 'papers', 'tape_dispenser']



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/45: 100%|██████████| 425/425 [05:35<00:00,  1.26it/s, loss=0.4154, acc=1.00%, lr=2.13e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 5.2811, Accuracy: 7.42%


Epoch 2/45: 100%|██████████| 425/425 [05:34<00:00,  1.27it/s, loss=0.3497, acc=5.74%, lr=4.25e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.85it/s]



Validation Loss: 4.5243, Accuracy: 13.34%


Epoch 3/45: 100%|██████████| 425/425 [05:34<00:00,  1.27it/s, loss=0.3078, acc=9.27%, lr=5.00e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.2027, Accuracy: 19.73%


Epoch 4/45: 100%|██████████| 425/425 [05:34<00:00,  1.27it/s, loss=0.2857, acc=12.72%, lr=4.98e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.0089, Accuracy: 22.18%


Epoch 5/45: 100%|██████████| 425/425 [05:34<00:00,  1.27it/s, loss=0.2689, acc=15.82%, lr=4.95e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.8408, Accuracy: 24.08%


Epoch 6/45: 100%|██████████| 425/425 [05:34<00:00,  1.27it/s, loss=0.2567, acc=17.98%, lr=4.91e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.8254, Accuracy: 24.80%


Epoch 7/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.2490, acc=19.00%, lr=4.85e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.7194, Accuracy: 24.73%


Epoch 8/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.2392, acc=22.03%, lr=4.79e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.7323, Accuracy: 26.05%


Epoch 9/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.2283, acc=24.81%, lr=4.71e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.6219, Accuracy: 26.97%


Epoch 10/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.2207, acc=27.21%, lr=4.61e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.7565, Accuracy: 27.66%


Epoch 11/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.2111, acc=29.96%, lr=4.51e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.6592, Accuracy: 27.76%


Epoch 12/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.2016, acc=33.05%, lr=4.39e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.7222, Accuracy: 28.57%


Epoch 13/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1920, acc=35.66%, lr=4.27e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.7458, Accuracy: 27.62%


Epoch 14/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1841, acc=38.31%, lr=4.13e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.7382, Accuracy: 28.34%


Epoch 15/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1747, acc=40.78%, lr=3.99e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.8038, Accuracy: 29.46%


Epoch 16/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1666, acc=44.30%, lr=3.84e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.8871, Accuracy: 28.75%


Epoch 17/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1577, acc=47.11%, lr=3.68e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 3.8804, Accuracy: 29.49%


Epoch 18/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1481, acc=49.71%, lr=3.52e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.0041, Accuracy: 29.91%


Epoch 19/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1404, acc=52.48%, lr=3.34e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.0787, Accuracy: 29.93%


Epoch 20/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1332, acc=55.01%, lr=3.17e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.1244, Accuracy: 29.81%


Epoch 21/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1246, acc=58.06%, lr=2.99e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.1508, Accuracy: 29.14%


Epoch 22/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1175, acc=60.32%, lr=2.81e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.1996, Accuracy: 29.63%


Epoch 23/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1101, acc=63.19%, lr=2.62e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.3124, Accuracy: 29.75%


Epoch 24/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.1023, acc=66.20%, lr=2.44e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.3256, Accuracy: 30.00%


Epoch 25/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0983, acc=67.21%, lr=2.26e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.4163, Accuracy: 29.95%


Epoch 26/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0921, acc=69.32%, lr=2.07e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.5295, Accuracy: 29.93%


Epoch 27/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0875, acc=71.45%, lr=1.89e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.5463, Accuracy: 30.34%


Epoch 28/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0827, acc=73.27%, lr=1.72e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.5827, Accuracy: 30.50%


Epoch 29/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0785, acc=74.14%, lr=1.54e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.85it/s]



Validation Loss: 4.6331, Accuracy: 30.53%


Epoch 30/45: 100%|██████████| 425/425 [05:34<00:00,  1.27it/s, loss=0.0763, acc=75.17%, lr=1.38e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.6652, Accuracy: 30.37%


Epoch 31/45: 100%|██████████| 425/425 [05:34<00:00,  1.27it/s, loss=0.0710, acc=76.59%, lr=1.22e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.7288, Accuracy: 30.27%


Epoch 32/45: 100%|██████████| 425/425 [05:34<00:00,  1.27it/s, loss=0.0685, acc=77.84%, lr=1.06e-05]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.7402, Accuracy: 30.67%


Epoch 33/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0655, acc=79.19%, lr=9.15e-06]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.7959, Accuracy: 30.90%


Epoch 34/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0639, acc=79.44%, lr=7.77e-06]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.8101, Accuracy: 30.92%


Epoch 35/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0616, acc=80.38%, lr=6.48e-06]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.8450, Accuracy: 30.69%


Epoch 36/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0602, acc=81.31%, lr=5.30e-06]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.8625, Accuracy: 30.83%


Epoch 37/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0585, acc=81.82%, lr=4.22e-06]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.9000, Accuracy: 30.69%


Epoch 38/45: 100%|██████████| 425/425 [05:34<00:00,  1.27it/s, loss=0.0575, acc=82.00%, lr=3.25e-06]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]



Validation Loss: 4.9016, Accuracy: 30.92%


Epoch 39/45: 100%|██████████| 425/425 [05:35<00:00,  1.27it/s, loss=0.0566, acc=82.62%, lr=2.40e-06]
Validation: 100%|██████████| 355/355 [03:12<00:00,  1.84it/s]
  model.load_state_dict(torch.load('best_model.pth'))



Validation Loss: 4.9145, Accuracy: 30.81%

Early stopping triggered after 39 epochs
