In [1]:
!pip install transformers torch torchvision scikit-learn pandas pillow tqdm openpyxl


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
!pip install ftfy


Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy
Successfully installed ftfy-6.3.1


In [None]:


import sys
sys.path.append("/kaggle/input/dataset5")

import os
import argparse
from pathlib import Path
import pandas as pd
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as T
import torchvision.models as models
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from normalizer import normalize


import timm
from sentence_transformers import SentenceTransformer


def clean_text(text):
    if not text or pd.isna(text):
        return ""
    text = str(text)
    
    text = text.strip()
    return text


class MemeDataset(Dataset):
    def __init__(self, df, images_dir, tokenizer, max_length=128, image_size=224,
                 use_normalizer=True):
        self.df = df.reset_index(drop=True)
        self.images_dir = Path(images_dir)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.image_size = image_size
        self.use_normalizer = use_normalizer

        self.transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        img_path = self.images_dir / row['image_file_name']
        try:
            img = Image.open(img_path).convert('RGB')
        except Exception:
            # fallback black image
            img = Image.new('RGB', (self.image_size, self.image_size), color=(0, 0, 0))
        img = self.transform(img)

        text = clean_text(row.get('text', ""))

        if self.use_normalizer and text:
            try:
                text = normalize(text)
            except Exception:
                pass

        
        tok = self.tokenizer(text,
                             truncation=True,
                             padding='max_length',
                             max_length=self.max_length,
                             return_tensors='pt')
        input_ids = tok['input_ids'].squeeze(0)
        attention_mask = tok['attention_mask'].squeeze(0)

        label = int(row['label'])
        return {
            'image': img,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': torch.tensor(label, dtype=torch.long),
            'text': text
        }


def collate_fn(batch):
    images = torch.stack([b['image'] for b in batch])
    input_ids = torch.stack([b['input_ids'] for b in batch])
    attention_mask = torch.stack([b['attention_mask'] for b in batch])
    labels = torch.stack([b['label'] for b in batch])
    texts = [b['text'] for b in batch]
    return {
        'image': images,
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'texts': texts
    }


class MultimodalClassifier(nn.Module):
    """
    ViT (timm) + SentenceTransformer multimodal classifier.
    If `text_model` is a sentence-transformers name -> use SentenceTransformer.
    Otherwise it will try to load an AutoModel (e.g., BanglishBERT).
    """

    def __init__(self, text_model: str, text_feat_dim: int, num_labels: int,
                 freeze_text=False, freeze_image=False):
        super().__init__()

        self.text_model_name = text_model
        self.use_sentence_transformer = "sentence-transformers" in text_model

        # TEXT ENCODER
        if self.use_sentence_transformer:
            # SentenceTransformer loads a model and provides .encode(...)
            self.text_encoder = SentenceTransformer(text_model)
            raw_text_dim = self.text_encoder.get_sentence_embedding_dimension()
        else:
            # fallback to huggingface AutoModel
            self.text_encoder = AutoModel.from_pretrained(text_model)
            raw_text_dim = self.text_encoder.config.hidden_size

        if freeze_text:
            for p in self.text_encoder.parameters():
                p.requires_grad = False

        # Project text embedding into a 512-dim common space
        self.text_proj = nn.Linear(raw_text_dim, 512)

        # IMAGE ENCODER (ViT via timm). Remove classification head (num_classes=0)
        self.image_encoder = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
        image_feat_dim = self.image_encoder.num_features if hasattr(self.image_encoder, "num_features") else 768
        if freeze_image:
            for p in self.image_encoder.parameters():
                p.requires_grad = False
        self.image_proj = nn.Linear(image_feat_dim, 512)

        # classifier head
        self.classifier = nn.Sequential(
            nn.Linear(512 + 512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_labels)
        )

    def forward(self, images, texts=None, input_ids=None, attention_mask=None):
        # images -> [B, 3, H, W]
        img_feat = self.image_encoder(images)           
        img_feat = self.image_proj(img_feat)            

        # text encoding:
        if self.use_sentence_transformer:
            
            if texts is None:
                raise ValueError("texts must be provided when using SentenceTransformer as text_model")
            
            txt_feat = self.text_encoder.encode(texts, convert_to_tensor=True)
            if isinstance(txt_feat, torch.Tensor):
                txt_feat = txt_feat.to(img_feat.device)
            else:
                
                txt_feat = torch.tensor(np.asarray(txt_feat), dtype=torch.float32, device=img_feat.device)
        else:
            
            assert input_ids is not None and attention_mask is not None, "input_ids/attention_mask required for AutoModel text encoder"
            text_out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
            if hasattr(text_out, "pooler_output") and text_out.pooler_output is not None:
                txt_feat = text_out.pooler_output
            else:
                last_hidden = text_out.last_hidden_state
                mask = attention_mask.unsqueeze(-1).float()
                summed = (last_hidden * mask).sum(1)
                denom = mask.sum(1).clamp(min=1e-9)
                txt_feat = summed / denom

        txt_feat = self.text_proj(txt_feat)            # (B, 512)

        fused = torch.cat([img_feat, txt_feat], dim=1)  # (B, 1024)
        logits = self.classifier(fused)
        return logits


def find_discrepancies(df, images_dir):
    images_dir = Path(images_dir)
    referenced = set(df['image_file_name'].astype(str).tolist())
    actual = set([p.name for p in images_dir.glob('*') if p.is_file()])
    missing = sorted(list(referenced - actual))
    orphan = sorted(list(actual - referenced))
    return missing, orphan


def prepare_dataframe(path, images_dir, drop_label_value=2):
    df = pd.read_excel(path)
    assert 'image_file_name' in df.columns and 'text' in df.columns and 'label' in df.columns, \
        "metadata.xlsx must contain columns: image_file_name, text, label"

    df = df[df['label'] != drop_label_value].copy()
    df['image_file_name'] = df['image_file_name'].astype(str).str.strip()

    missing, orphan = find_discrepancies(df, images_dir)
    if missing:
        print(f"Missing images for {len(missing)} metadata entries")
        df = df[~df['image_file_name'].isin(missing)].copy()

    if orphan:
        print(f"Found {len(orphan)} orphan image files not in metadata (showing up to 20):")
        for o in orphan[:20]:
            print("  -", o)
        if len(orphan) > 20:
            print("  ... and more")

    unique_labels = sorted(df['label'].unique().tolist())
    label_map = {orig: idx for idx, orig in enumerate(unique_labels)}
    df['label'] = df['label'].map(label_map)
    print("Label mapping:", label_map)
    return df, orphan, label_map


def compute_class_weights(df, power=0.5):
    counts = df['label'].value_counts().sort_index().values
    weights = (1.0 / counts) ** power
    weights = weights / weights.sum() * len(weights)
    sample_weights = df['label'].map(lambda x: weights[x]).values
    return sample_weights


def train_one_epoch(model, dataloader, optimizer, criterion, device, scheduler=None):
    model.train()
    total_loss = 0.0
    for batch in tqdm(dataloader, desc="Train", leave=False):
        images = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        texts = batch['texts']

        optimizer.zero_grad()
        logits = model(images=images, texts=texts, input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item() * images.size(0)
    return total_loss / len(dataloader.dataset)


@torch.no_grad()
def evaluate(model, dataloader, device, label_map):
    model.eval()
    preds = []
    trues = []
    for batch in tqdm(dataloader, desc="Eval", leave=False):
        images = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        texts = batch['texts']

        logits = model(images=images, texts=texts, input_ids=input_ids, attention_mask=attention_mask)
        batch_preds = torch.argmax(logits, dim=1).cpu().numpy().tolist()
        batch_trues = labels.cpu().numpy().tolist()
        preds.extend(batch_preds)
        trues.extend(batch_trues)

    acc = accuracy_score(trues, preds)
    report = classification_report(trues, preds, digits=4)
    cm = confusion_matrix(trues, preds)
    return acc, report, trues, preds, cm


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device:", device)

    df, orphan_files, label_map = prepare_dataframe(args.data, args.images_dir, drop_label_value=2)

    if args.delete_orphans and orphan_files:
        for fname in orphan_files:
            p = Path(args.images_dir) / fname
            try:
                p.unlink()
            except Exception as e:
                print("Could not delete:", p, e)
        print("Deleted orphans.")

    train_df, test_df = train_test_split(df, test_size=args.test_size, stratify=df['label'], random_state=42)
    train_df, val_df = train_test_split(train_df, test_size=args.val_size, stratify=train_df['label'], random_state=42)
    print(f"Train / Val / Test sizes: {len(train_df)} / {len(val_df)} / {len(test_df)}")

    tokenizer = AutoTokenizer.from_pretrained(args.text_model)

    use_normalizer = not args.disable_normalizer
    print(f"Text normalization: {'enabled' if use_normalizer else 'disabled'}")

    train_dataset = MemeDataset(train_df, args.images_dir, tokenizer,
                                max_length=args.max_length, image_size=args.image_size,
                                use_normalizer=use_normalizer)
    val_dataset = MemeDataset(val_df, args.images_dir, tokenizer,
                              max_length=args.max_length, image_size=args.image_size,
                              use_normalizer=use_normalizer)
    test_dataset = MemeDataset(test_df, args.images_dir, tokenizer,
                               max_length=args.max_length, image_size=args.image_size,
                               use_normalizer=use_normalizer)

    sample_weights = compute_class_weights(train_df, power=args.weight_power)
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(train_dataset), replacement=True)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler,
                              collate_fn=collate_fn, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            collate_fn=collate_fn, num_workers=0, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             collate_fn=collate_fn, num_workers=0, pin_memory=True)

    num_labels = len(label_map)
    model = MultimodalClassifier(text_model=args.text_model,
                                 text_feat_dim=args.text_feat_dim,
                                 num_labels=num_labels,
                                 freeze_text=args.freeze_text,
                                 freeze_image=args.freeze_image)
    model.to(device)

    # loss, optimizer, scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)

    total_steps = len(train_loader) * args.epochs
    warmup_steps = int(0.1 * total_steps) if total_steps > 0 else 0

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return max(0.0, float(total_steps - current_step) / float(max(1, total_steps - warmup_steps)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    best_val_acc = 0.0
    os.makedirs(args.out_dir, exist_ok=True)

    print("\nStarting training...")
    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch {epoch}/{args.epochs}")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scheduler)
        print(f"Train loss: {train_loss:.6f}")

        val_acc, val_report, _, _, val_cm = evaluate(model, val_loader, device, label_map)
        print(f"Validation Acc: {val_acc:.4f}")

        print("Validation classification report:")
        print(val_report)
        print("Confusion Matrix:")
        print(val_cm)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'label_map': label_map,
                'epoch': epoch,
                'val_acc': val_acc
            }, os.path.join(args.out_dir, "best_model.pt"))
            print(f"✓ Saved best model (val_acc: {val_acc:.4f})")

    # final evaluation on test
    print("\n" + "="*60)
    print("FINAL EVALUATION ON TEST SET")
    print("="*60)
    ckpt_path = os.path.join(args.out_dir, "best_model.pt")
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        print(f"Loaded model from epoch {ckpt.get('epoch', '?')} with val_acc: {ckpt.get('val_acc', 0):.4f}")
    else:
        print("No checkpoint found, using current model weights.")

    test_acc, test_report, trues, preds, test_cm = evaluate(model, test_loader, device, label_map)
    print(f"\n{'='*60}")
    print(f"TEST ACCURACY: {test_acc:.4f} ({test_acc*100:.2f}%)")
    print(f"{'='*60}")
    print("\nTest Classification Report:")
    print(test_report)
    print("\nTest Confusion Matrix:")
    print(test_cm)

    out = test_df.reset_index(drop=True).copy()
    out['pred_idx'] = preds
    inv_map = {v: k for k, v in label_map.items()}
    out['pred_orig'] = out['pred_idx'].map(inv_map)
    out['true_orig'] = out['label'].map(inv_map)
    out.to_csv(os.path.join(args.out_dir, "test_predictions.csv"), index=False)
    print(f"\nResults saved to {args.out_dir}/")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='/kaggle/input/dataset5/metadata.xlsx')
    parser.add_argument('--images_dir', type=str, default='/kaggle/input/dataset5/images')
    parser.add_argument('--out_dir', type=str, default='/kaggle/working/output')
    parser.add_argument('--epochs', type=int, default=10)            
    parser.add_argument('--batch_size', type=int, default=10)        
    parser.add_argument('--lr', type=float, default=2e-5)
    parser.add_argument('--weight_decay', type=float, default=0.005)
    
    parser.add_argument('--text_model', type=str,
                        default='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    parser.add_argument('--max_length', type=int, default=64)
    parser.add_argument('--image_size', type=int, default=224)
    parser.add_argument('--val_size', type=float, default=0.1)
    parser.add_argument('--test_size', type=float, default=0.1)
    parser.add_argument('--text_feat_dim', type=int, default=768)   
    parser.add_argument('--hidden_dim', type=int, default=512)
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--patience', type=int, default=7)
    parser.add_argument('--weight_power', type=float, default=0.5)
    parser.add_argument('--freeze_text', action='store_true')
    parser.add_argument('--freeze_image', action='store_true')
    parser.add_argument('--delete-orphans', action='store_true')
    parser.add_argument('--disable-normalizer', action='store_true')
    parser.add_argument('--augment', action='store_true')
    parser.add_argument('--verbose', action='store_true')

    
    args = parser.parse_args([])
    main(args)


2025-11-24 05:21:06.632183: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763961666.843308      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763961666.902128      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

Device: cuda
Found 4 orphan image files not in metadata (showing up to 20):
  - FB_IMG_1751540473613.jpg
  - FB_IMG_1751739942837.jpg
  - FB_IMG_1754929300743.jpg
  - FB_IMG_1755921270397.jpg
Label mapping: {0: 0, 1: 1, 3: 2}
Train / Val / Test sizes: 5508 / 612 / 680


tokenizer_config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/645 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Text normalization: enabled


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/471M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]


Starting training...

Epoch 1/10


                                                        

Train loss: 0.973601


                                                     

Validation Acc: 0.5768
Validation classification report:
              precision    recall  f1-score   support

           0     0.5333    0.3256    0.4043       172
           1     0.5645    0.2652    0.3608       132
           2     0.5888    0.8506    0.6959       308

    accuracy                         0.5768       612
   macro avg     0.5622    0.4805    0.4870       612
weighted avg     0.5680    0.5768    0.5417       612

Confusion Matrix:
[[ 56  14 102]
 [ 16  35  81]
 [ 33  13 262]]
✓ Saved best model (val_acc: 0.5768)

Epoch 2/10


                                                        

Train loss: 0.713196


                                                     

Validation Acc: 0.5703
Validation classification report:
              precision    recall  f1-score   support

           0     0.4866    0.5291    0.5070       172
           1     0.5472    0.2197    0.3135       132
           2     0.6156    0.7435    0.6735       308

    accuracy                         0.5703       612
   macro avg     0.5498    0.4974    0.4980       612
weighted avg     0.5646    0.5703    0.5491       612

Confusion Matrix:
[[ 91  11  70]
 [ 30  29  73]
 [ 66  13 229]]

Epoch 3/10


                                                        

Train loss: 0.413332


                                                     

Validation Acc: 0.5523
Validation classification report:
              precision    recall  f1-score   support

           0     0.4907    0.4593    0.4745       172
           1     0.4184    0.3106    0.3565       132
           2     0.6176    0.7078    0.6596       308

    accuracy                         0.5523       612
   macro avg     0.5089    0.4926    0.4969       612
weighted avg     0.5389    0.5523    0.5422       612

Confusion Matrix:
[[ 79  20  73]
 [ 29  41  62]
 [ 53  37 218]]

Epoch 4/10


                                                        

Train loss: 0.280361


                                                     

Validation Acc: 0.5523
Validation classification report:
              precision    recall  f1-score   support

           0     0.5676    0.3663    0.4452       172
           1     0.3876    0.3788    0.3831       132
           2     0.6048    0.7305    0.6618       308

    accuracy                         0.5523       612
   macro avg     0.5200    0.4919    0.4967       612
weighted avg     0.5475    0.5523    0.5408       612

Confusion Matrix:
[[ 63  30  79]
 [ 14  50  68]
 [ 34  49 225]]

Epoch 5/10


                                                        

Train loss: 0.189948


                                                     

Validation Acc: 0.5605
Validation classification report:
              precision    recall  f1-score   support

           0     0.4607    0.5116    0.4848       172
           1     0.5278    0.2879    0.3725       132
           2     0.6218    0.7045    0.6606       308

    accuracy                         0.5605       612
   macro avg     0.5368    0.5014    0.5060       612
weighted avg     0.5562    0.5605    0.5491       612

Confusion Matrix:
[[ 88  13  71]
 [ 33  38  61]
 [ 70  21 217]]

Epoch 6/10


                                                        

Train loss: 0.132969


                                                     

Validation Acc: 0.5507
Validation classification report:
              precision    recall  f1-score   support

           0     0.5040    0.3663    0.4242       172
           1     0.4250    0.2576    0.3208       132
           2     0.5897    0.7792    0.6713       308

    accuracy                         0.5507       612
   macro avg     0.5062    0.4677    0.4721       612
weighted avg     0.5301    0.5507    0.5263       612

Confusion Matrix:
[[ 63  19  90]
 [ 21  34  77]
 [ 41  27 240]]

Epoch 7/10


                                                        

Train loss: 0.084700


                                                     

Validation Acc: 0.5621
Validation classification report:
              precision    recall  f1-score   support

           0     0.5097    0.4593    0.4832       172
           1     0.4815    0.1970    0.2796       132
           2     0.5931    0.7760    0.6723       308

    accuracy                         0.5621       612
   macro avg     0.5281    0.4774    0.4783       612
weighted avg     0.5456    0.5621    0.5344       612

Confusion Matrix:
[[ 79   8  85]
 [ 27  26  79]
 [ 49  20 239]]

Epoch 8/10


                                                        

Train loss: 0.062970


                                                     

Validation Acc: 0.5376
Validation classification report:
              precision    recall  f1-score   support

           0     0.4897    0.4128    0.4479       172
           1     0.4040    0.3030    0.3463       132
           2     0.5924    0.7078    0.6450       308

    accuracy                         0.5376       612
   macro avg     0.4954    0.4745    0.4797       612
weighted avg     0.5229    0.5376    0.5252       612

Confusion Matrix:
[[ 71  21  80]
 [ 22  40  70]
 [ 52  38 218]]

Epoch 9/10


                                                        

Train loss: 0.040545


                                                     

Validation Acc: 0.5588
Validation classification report:
              precision    recall  f1-score   support

           0     0.5657    0.3256    0.4133       172
           1     0.4697    0.2348    0.3131       132
           2     0.5705    0.8279    0.6755       308

    accuracy                         0.5588       612
   macro avg     0.5353    0.4628    0.4673       612
weighted avg     0.5474    0.5588    0.5236       612

Confusion Matrix:
[[ 56  13 103]
 [ 12  31  89]
 [ 31  22 255]]

Epoch 10/10


                                                        

Train loss: 0.029421


                                                     

Validation Acc: 0.5539
Validation classification report:
              precision    recall  f1-score   support

           0     0.5164    0.3663    0.4286       172
           1     0.4231    0.2500    0.3143       132
           2     0.5898    0.7890    0.6750       308

    accuracy                         0.5539       612
   macro avg     0.5098    0.4684    0.4726       612
weighted avg     0.5332    0.5539    0.5279       612

Confusion Matrix:
[[ 63  20  89]
 [ 19  33  80]
 [ 40  25 243]]

FINAL EVALUATION ON TEST SET


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([scalar])` or the `torch.serialization.safe_globals([scalar])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.