### Notebook 03: Finetune Advanced Vision Models (CLIP and ViT)

As we moved towards multimodality we want to acquired some experience and intuiton on modern vision models. These moderns were trained with text + images so they will have a better performance with memes, screenshots, visual cues etc. As we saw in our EDA notebook a lot of our images have text that could be relevant to our classifications.

This notebooks covers the full preprocessing pipeline and training for Modern Vision Models (CLIP, ViT) for image classification.

This notebook trains two SEPARATE models per each achitecture: 
  - Model A: Stance Classification (support/oppose)
  - Model B: Persuasiveness Classification (yes/no)

Images are resized, normalized and batched. Models are fine-tuned with cross-entropy loss and evaluated on our goal metric - F1-Score (Binary).

Models tested:
- CLIP (OpenAI) - Trained on image-text pairs, ideal for memes
- Vision Transformer (ViT) - State-of-the-art for image classification

Intuition: Find the best vision encoder for multimodal fusion.

In [1]:
# Libraries
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from tqdm.auto import tqdm
from scipy.stats import pearsonr
from transformers import CLIPProcessor, CLIPModel,ViTImageProcessor, ViTForImageClassification,get_linear_schedule_with_warmup
import warnings

warnings.filterwarnings("ignore")


# Random seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Seed:  {SEED}")
print(f"Using device: {device}")

Seed:  42
Using device: cuda


In [2]:
#Paths
DATA_PATH = "../../data/"
IMG_PATH = "../../data/images"
OUTPUT_DIR = "../../models/vision/advanced_vision_models/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

train_path = os.path.join(DATA_PATH,"train.csv")
dev_path   = os.path.join(DATA_PATH,"dev.csv")
test_path  = os.path.join(DATA_PATH,"test.csv")

#Load Data
df_train = pd.read_csv(train_path)
df_dev   = pd.read_csv(dev_path)
df_test  = pd.read_csv(test_path)

# Map labels to ints
stance_2id = {"oppose": 0, "support": 1}
pers_2id = {"no": 0, "yes": 1}

for df in [df_train, df_dev, df_test]:
    df["label"] = df["stance"].map(stance_2id)
    df["persuasiveness_label"] = df["persuasiveness"].map(pers_2id)


print(f"\n Train label distribution:")
print(f"\n Stance: \n Oppose: {(df_train['label']==0).sum()}\n Support: {(df_train['label']==1).sum()}")
print(f"\n\n  Persuasiveness \n No: {(df_train['persuasiveness_label']==0).sum()}\n Yes: {(df_train['persuasiveness_label']==1).sum()}")



 Train label distribution:

 Stance: 
 Oppose: 1095
 Support: 719


  Persuasiveness 
 No: 1285
 Yes: 529


In [3]:
# Load baseline results for comparison
baseline_results = pd.read_csv(os.path.join("../../results/vision/baseline_models/", "baseline_results.csv"))
print(f"\n Baseline Results ({baseline_results['Model'][0]}):")
print(f"Baseline F1 Stance: {baseline_results[baseline_results['Task']=='Stance']['f1'].values[0]:.4f}")
print(f"Baseline F1 Persuasiveness: {baseline_results[baseline_results['Task']=='Persuasiveness']['f1'].values[0]:.4f}")


 Baseline Results (RestNet50):
Baseline F1 Stance: 0.4854
Baseline F1 Persuasiveness: 0.4114


### Advanced Vision Models

In [4]:
CLIP = "openai/clip-vit-base-patch32"  # Lightweight CLIP
VIT = "google/vit-base-patch16-224"    # Standard ViT

In [5]:
#Training Hyperparameters
BATCH_SIZE = 16  # Smaller for larger models (CLIP, ViT)
NUM_EPOCHS = 10
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 1e-4
WARMUP_RATIO = 0.1
PATIENCE = 3  # For early stopping
NUM_WORKERS = 1  # For DataLoader

In [6]:
# We create custom datasets for CLIP and ViT
class VisionDataset(Dataset):
    """
    Custom dataset that works with both CLIP and ViT processors.
    CLIP and ViT have different preprocessing requirements
    """
    
    def __init__(self, dataframe, image_dir, processor, target_task='stance'):
        """
            dataframe: DataFrame with image paths and labels
            image_dir: Directory containing images
            processor: CLIPProcessor or ViTImageProcessor
            target_task: 'stance' or 'persuasiveness'
        """
        self.df = dataframe.reset_index(drop=True)
        self.image_dir = image_dir
        self.processor = processor
        self.target_task = target_task
        
        # Image paths
        self.df['_image_path'] = self.df['tweet_id'].apply(lambda x: os.path.join(image_dir, f"{x}.jpg"))
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load image
        try:
            image = Image.open(row['_image_path']).convert('RGB')
        except Exception as e:
            # Fallback to blank image if loading fails
            image = Image.new('RGB', (224, 224), color=(0, 0, 0))
        
        # Get label based on target task
        if self.target_task == 'stance':
            label = row['label']
        elif self.target_task == 'persuasiveness':
            label = row['persuasiveness_label']
        else:
            raise ValueError(f"Unknown task: {self.target_task}")
        
        return {
            'image': image,
            'label': torch.tensor(label, dtype=torch.long),
            'tweet_id': row['tweet_id']
        }

def custom_collate_fn(batch, processor):
    """
    CLIP/ViT processors expect PIL images, not tensors.
    Thus, we need a custom collate function that applies the processor to the batch.
    """

    images = [item['image'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])
    tweet_ids = [item['tweet_id'] for item in batch]
    
    # Process images with CLIP/ViT processor
    processed = processor(images=images, return_tensors="pt")
    
    return {
        'pixel_values': processed['pixel_values'],
        'labels': labels,
        'tweet_ids': tweet_ids
    }


### CLIP Model

Why CLIP?
- Trained on 400M image-text pairs
- Understands text IN images (perfect for memes/posters)
- Better semantic understanding than pure vision models

In [7]:

class CLIPClassifier(nn.Module):
    """
    CLIP-based classifier for stance/persuasiveness.
    
    Architecture:
      - CLIP Vision Encoder (pretrained)
      - Classification head (trainable)
    """
    
    def __init__(self, model_name="openai/clip-vit-base-patch32", num_classes=2, freeze_encoder=True):
        super(CLIPClassifier, self).__init__()
        
        # Load pretrained CLIP model
        self.clip = CLIPModel.from_pretrained(model_name)
        
        # Get vision model (we only need image encoder, not text)
        self.vision_model = self.clip.vision_model
        
        # Freeze encoder if specified (faster training, less overfitting)
        if freeze_encoder:
            for param in self.vision_model.parameters():
                param.requires_grad = False

            for name, param in self.vision_model.named_parameters():
                if "encoder.layers.11" in name:
                    param.requires_grad = True
            print("Clip encoder partially frozen (only training last layer and classifier head)")
        else:
            print("CLIP encoder not frozen (training entire model)")
        
        # Classification head
        hidden_dim = self.clip.config.projection_dim  # 512
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
        
        print(f"CLIP model initialized: {model_name}")
        print(f"     - Hidden dim: {hidden_dim}")
        print(f"     - Num classes: {num_classes}")
        print(f"     - Encoder frozen: {freeze_encoder}")
    
    def forward(self, pixel_values):
        """
        Forward pass.
        
        Args:
            pixel_values: Preprocessed images [batch, 3, 224, 224]
        
        Returns:
            logits: [batch, num_classes]
        """
        # Get CLIP vision embeddings
        vision_outputs = self.vision_model(pixel_values=pixel_values)
        
        # Take the CLS token from last hidden state
        cls_embedding = vision_outputs.last_hidden_state[:, 0, :]  # [batch, 768]
        
        # Project to CLIP's embedding space (512)
        pooled_output = self.clip.visual_projection(cls_embedding)  # [batch, 512]
                
        # Classification
        logits = self.classifier(pooled_output)
        
        return logits
    

# Test CLIP initialization
print("\nTesting CLIP model initialization...")
clip_processor = CLIPProcessor.from_pretrained(CLIP)
test_clip = CLIPClassifier(model_name=CLIP, num_classes=2, freeze_encoder=True)
print(f"  Total parameters: {sum(p.numel() for p in test_clip.parameters()):,}")
print(f"  Trainable parameters: {sum(p.numel() for p in test_clip.parameters() if p.requires_grad):,}")
del test_clip


Testing CLIP model initialization...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Clip encoder partially frozen (only training last layer and classifier head)
CLIP model initialized: openai/clip-vit-base-patch32
     - Hidden dim: 512
     - Num classes: 2
     - Encoder frozen: True
  Total parameters: 151,409,155
  Trainable parameters: 71,041,027


In [8]:
#Training function for CLIP and ViT
def train_advanced_model(model,processor,train_df,dev_df,target_task='stance',
                         num_epochs=10,
                         learning_rate=2e-5,
                         batch_size=16,
                         device=device):
    """
    Train CLIP or ViT model
    """

    print(f"Training for: {target_task.upper()}")
    
    # We create our datasets
    train_dataset = VisionDataset(train_df, IMG_PATH, processor, target_task)
    dev_dataset = VisionDataset(dev_df, IMG_PATH, processor, target_task)
    
    # We create our dataloaders with custom collate
    train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=NUM_WORKERS,
        collate_fn=lambda batch: custom_collate_fn(batch, processor),pin_memory=True)
    
    # We do the same for dev
    dev_loader = DataLoader(dev_dataset,batch_size=batch_size,shuffle=False,num_workers=NUM_WORKERS,
                            collate_fn=lambda batch: custom_collate_fn(batch, processor),pin_memory=True)
    
    print(f"  Train batches: {len(train_loader)}")
    print(f"  Dev batches: {len(dev_loader)}")
    model = model.to(device)
    
    # We initialize loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=WEIGHT_DECAY)
    
    # Learning rate scheduler with warmup
    num_training_steps = len(train_loader) * num_epochs
    num_warmup_steps = int(num_training_steps * WARMUP_RATIO)
    
    scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=num_warmup_steps,
                                                num_training_steps=num_training_steps)
    
    # Training history
    history = {
        'train_loss': [],
        'dev_f1': [],
        'dev_loss': [],
        'learning_rates': []
    }
    
    # Early stopping
    best_f1 = 0.0
    best_model_state = None
    patience_counter = 0
    
    # Training loop
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 40)
        
        # === TRAINING ===
        model.train()
        train_loss = 0.0
        
        pbar = tqdm(train_loader, desc="Training", leave=False)
        for batch in pbar:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)
            
            # Forward
            logits = model(pixel_values)
            loss = criterion(logits, labels)
            
            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            train_loss += loss.item() * pixel_values.size(0)
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})
        
        train_loss = train_loss / len(train_dataset)
        
        # === VALIDATION ===
        model.eval()
        dev_loss = 0.0
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for batch in tqdm(dev_loader, desc="Validation", leave=False):
                pixel_values = batch['pixel_values'].to(device)
                labels = batch['labels'].to(device)
                
                logits = model(pixel_values)
                loss = criterion(logits, labels)
                
                dev_loss += loss.item() * pixel_values.size(0)
                preds = torch.argmax(logits, dim=1)
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        dev_loss = dev_loss / len(dev_dataset)
        dev_f1 = f1_score(all_labels, all_preds, average='binary', pos_label=1, zero_division=0)
        
        # Save history
        current_lr = optimizer.param_groups[0]['lr']
        history['train_loss'].append(train_loss)
        history['dev_f1'].append(dev_f1)
        history['dev_loss'].append(dev_loss)
        history['learning_rates'].append(current_lr)
        
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Dev Loss: {dev_loss:.4f} | Dev F1: {dev_f1:.4f}")
        print(f"LR: {current_lr:.2e}")
        
        # Early stopping
        if dev_f1 > best_f1:
            best_f1 = dev_f1
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
            print(f"   New best F1: {best_f1:.4f}")
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{PATIENCE})")
            
            if patience_counter >= PATIENCE:
                print(f"\n  Early stopping at epoch {epoch + 1}")
                break
    
    # Load best model
    model.load_state_dict(best_model_state)
    model = model.to(device)
    
    print(f"\n  Training complete! Best F1: {best_f1:.4f}")
    
    return model, history

In [9]:
# Evaluate Function for CLIP and ViT
def evaluate_advanced_model(model, processor, test_df, target_task='stance', device=device):
    
    test_dataset = VisionDataset(test_df,IMG_PATH, processor, target_task)
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=lambda batch: custom_collate_fn(batch, processor))
    
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing", leave=False):
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)
            
            logits = model(pixel_values)
            preds = torch.argmax(logits, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    f1 = f1_score(all_labels, all_preds, average='binary', pos_label=1)
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, average='binary', pos_label=1, zero_division=0)
    recc = recall_score(all_labels, all_preds, average='binary', pos_label=1)
    cm = confusion_matrix(all_labels, all_preds)

    
    return {
        'f1': f1,
        'accuracy': acc,
        'precision': prec,
        'recall': recc,
        'confusion_matrix': cm,
        'y_true': np.array(all_labels),
        'y_pred': np.array(all_preds)
    }


### We Initialize, Train and Evaluate CLIP Model (Stance and Persuasiveness)

In [10]:
# CLIP for stance
print("\nInitializing CLIP model for stance...")

#Initialize
clip_stance_model = CLIPClassifier(model_name=CLIP,num_classes=2,freeze_encoder=True)

# Train CLIP on stance
clip_stance_model, clip_stance_history = train_advanced_model(
    model=clip_stance_model,
    processor=clip_processor,
    train_df=df_train,
    dev_df=df_dev,
    target_task='stance',
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    device=device)

# Evaluate on test set
print("\n  Evaluating CLIP (Stance) on test set...")
clip_stance_results = evaluate_advanced_model(
    clip_stance_model, clip_processor, df_test, 'stance', device)


print("\n\n\n CLIP - STANCE TEST RESULTS")
print(f"F1 Score:  {clip_stance_results['f1']:.4f}")
print(f"Accuracy:  {clip_stance_results['accuracy']:.4f}")
print(f"Baseline (ResNet50): {baseline_results[baseline_results['Task']=='Stance']['f1'].values[0]:.4f}")
print(f"Improvement: {(clip_stance_results['f1'] - baseline_results[baseline_results['Task']=='Stance']['f1'].values[0]):.4f}")


# Save model
torch.save(clip_stance_model.state_dict(), os.path.join(OUTPUT_DIR, 'clip_stance_best.pth'))



Initializing CLIP model for stance...
Clip encoder partially frozen (only training last layer and classifier head)
CLIP model initialized: openai/clip-vit-base-patch32
     - Hidden dim: 512
     - Num classes: 2
     - Encoder frozen: True
Training for: STANCE
  Train batches: 114
  Dev batches: 13

Epoch 1/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.6734
Dev Loss: 0.6291 | Dev F1: 0.0000
LR: 2.00e-05
  No improvement (1/3)

Epoch 2/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


Train Loss: 0.5766
Dev Loss: 0.5264 | Dev F1: 0.6187
LR: 1.78e-05
   New best F1: 0.6187

Epoch 3/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Train Loss: 0.4163
Dev Loss: 0.5206 | Dev F1: 0.6711
LR: 1.56e-05
   New best F1: 0.6711

Epoch 4/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child processException ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
Exception ignored in:   File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
<function _MultiProcessingDataLoaderIter.__del__

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.2636
Dev Loss: 0.5704 | Dev F1: 0.6803
LR: 1.33e-05
   New best F1: 0.6803

Epoch 5/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


Train Loss: 0.1369
Dev Loss: 0.7433 | Dev F1: 0.6757
LR: 1.11e-05
  No improvement (1/3)

Epoch 6/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Train Loss: 0.0663
Dev Loss: 0.9250 | Dev F1: 0.6763
LR: 8.89e-06
  No improvement (2/3)

Epoch 7/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Train Loss: 0.0346
Dev Loss: 1.0858 | Dev F1: 0.6569
LR: 6.67e-06
  No improvement (3/3)

  Early stopping at epoch 7

  Training complete! Best F1: 0.6803

  Evaluating CLIP (Stance) on test set...


Testing:   0%|          | 0/19 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo




 CLIP - STANCE TEST RESULTS
F1 Score:  0.6340
Accuracy:  0.6767
Baseline (ResNet50): 0.4854
Improvement: 0.1486


In [11]:
# CLIP for Persuasiveness
print("\nInitializing CLIP model for persuasiveness...")

# Initialize new CLIP model for persuasiveness
clip_pers_model = CLIPClassifier(model_name=CLIP,num_classes=2,freeze_encoder=True)

# Train
clip_pers_model, clip_pers_history = train_advanced_model(
    model=clip_pers_model,
    processor=clip_processor,
    train_df=df_train,
    dev_df=df_dev,
    target_task='persuasiveness',
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    device=device)

# Evaluate
clip_pers_results = evaluate_advanced_model(clip_pers_model, clip_processor, df_test, 'persuasiveness', device)


print("\n\n\n CLIP - PERSUASIVENESS TEST RESULTS")
print(f"F1 Score:  {clip_pers_results['f1']:.4f}")
print(f"Accuracy:  {clip_pers_results['accuracy']:.4f}")
print(f"Baseline (ResNet50): {baseline_results[baseline_results['Task']=='Persuasiveness']['f1'].values[0]:.4f}")
print(f"Improvement: {(clip_pers_results['f1'] - baseline_results[baseline_results['Task']=='Persuasiveness']['f1'].values[0]):.4f}")


# Save
torch.save(clip_pers_model.state_dict(), os.path.join(OUTPUT_DIR, 'clip_pers_best.pth'))



Initializing CLIP model for persuasiveness...
Clip encoder partially frozen (only training last layer and classifier head)
CLIP model initialized: openai/clip-vit-base-patch32
     - Hidden dim: 512
     - Num classes: 2
     - Encoder frozen: True
Training for: PERSUASIVENESS
  Train batches: 114
  Dev batches: 13

Epoch 1/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.6554
Dev Loss: 0.5932 | Dev F1: 0.0000
LR: 2.00e-05
  No improvement (1/3)

Epoch 2/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Train Loss: 0.5359
Dev Loss: 0.5100 | Dev F1: 0.4952
LR: 1.78e-05
   New best F1: 0.4952

Epoch 3/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Train Loss: 0.4397
Dev Loss: 0.5124 | Dev F1: 0.5641
LR: 1.56e-05
   New best F1: 0.5641

Epoch 4/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
Exception ignored in:   File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>    
if w.is_alive():Traceback (most recent call last):

  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        self._shutdown_workers()
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/da

Train Loss: 0.3431
Dev Loss: 0.5575 | Dev F1: 0.5614
LR: 1.33e-05
  No improvement (1/3)

Epoch 5/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.2275
Dev Loss: 0.6486 | Dev F1: 0.5487
LR: 1.11e-05
  No improvement (2/3)

Epoch 6/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Train Loss: 0.1460
Dev Loss: 0.7986 | Dev F1: 0.4694
LR: 8.89e-06
  No improvement (3/3)

  Early stopping at epoch 6

  Training complete! Best F1: 0.5641


Testing:   0%|          | 0/19 [00:00<?, ?it/s]




 CLIP - PERSUASIVENESS TEST RESULTS
F1 Score:  0.5202
Accuracy:  0.7233
Baseline (ResNet50): 0.4114
Improvement: 0.1088


### ViT Model

Why ViT?
- State-of-the-art for pure image classification
- Attention-based (captures global context better than CNNs)
- Good baseline to compare against CLIP

In [12]:
class ViTClassifier(nn.Module):
    """
    Architecture:
      - ViT backbone (pretrained on ImageNet-21k)
      - Classification head (replaced for binary classification)
    """
    
    def __init__(self, model_name="google/vit-base-patch16-224", num_classes=2, freeze_encoder=True):
        super(ViTClassifier, self).__init__()
        
        # Load pretrained ViT
        self.vit = ViTForImageClassification.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True  # Allow replacing classification head
        )

        if freeze_encoder:
            for param in self.vit.parameters():
                param.requires_grad = False

            for name, param in self.vit.named_parameters():
                if "encoder.layer.11" in name:
                    param.requires_grad = True
            print("ViT encoder partially frozen (only training last layer and classifier head)")
        else:
            print("ViT encoder not frozen (training entire model)")
        
        print(f"ViT model initialized: {model_name}")
        print(f"- Num classes: {num_classes}")
    
    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits

# Test ViT
print("\nTesting ViT model initialization...")
vit_processor = ViTImageProcessor.from_pretrained(VIT)
test_vit = ViTClassifier(model_name=VIT, num_classes=2,freeze_encoder=True)
print(f"  Total parameters: {sum(p.numel() for p in test_vit.parameters()):,}")
del test_vit



Testing ViT model initialization...


IOStream.flush timed out


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViT encoder partially frozen (only training last layer and classifier head)
ViT model initialized: google/vit-base-patch16-224
- Num classes: 2
  Total parameters: 85,800,194


### We Initialize, Train and Evaluate ViT Model (Stance and Persuasiveness)

In [13]:

# We initialize ViT for stance
vit_stance_model = ViTClassifier(VIT, num_classes=2, freeze_encoder=True)

vit_stance_model, vit_stance_history = train_advanced_model(
    model=vit_stance_model,
    processor=vit_processor,
    train_df=df_train,
    dev_df=df_dev,
    target_task='stance',
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    device=device)

vit_stance_results = evaluate_advanced_model(vit_stance_model, vit_processor, df_test, 'stance', device)


print("\n\n\n ViT - STANCE TEST RESULTS")
print(f"F1 Score:  {vit_stance_results['f1']:.4f}")
print(f"Accuracy:  {vit_stance_results['accuracy']:.4f}")
print(f"Baseline (ResNet50): {baseline_results[baseline_results['Task']=='Stance']['f1'].values[0]:.4f}")
print(f"Improvement: {(vit_stance_results['f1'] - baseline_results[baseline_results['Task']=='Stance']['f1'].values[0]):.4f}")



torch.save(vit_stance_model.state_dict(), os.path.join(OUTPUT_DIR, 'vit_stance_best.pth'))




Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViT encoder partially frozen (only training last layer and classifier head)
ViT model initialized: google/vit-base-patch16-224
- Num classes: 2
Training for: STANCE
  Train batches: 114
  Dev batches: 13

Epoch 1/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.7006
Dev Loss: 0.6584 | Dev F1: 0.2936
LR: 2.00e-05
   New best F1: 0.2936

Epoch 2/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Train Loss: 0.6438
Dev Loss: 0.6264 | Dev F1: 0.4274
LR: 1.78e-05
   New best F1: 0.4274

Epoch 3/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Train Loss: 0.5982
Dev Loss: 0.6118 | Dev F1: 0.4724
LR: 1.56e-05
   New best F1: 0.4724

Epoch 4/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.5659
Dev Loss: 0.6017 | Dev F1: 0.4793
LR: 1.33e-05
   New best F1: 0.4793

Epoch 5/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.5406
Dev Loss: 0.5974 | Dev F1: 0.4921
LR: 1.11e-05
   New best F1: 0.4921

Epoch 6/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.5212
Dev Loss: 0.5937 | Dev F1: 0.4960
LR: 8.89e-06
   New best F1: 0.4960

Epoch 7/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.5052
Dev Loss: 0.5921 | Dev F1: 0.4960
LR: 6.67e-06
  No improvement (1/3)

Epoch 8/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Train Loss: 0.4933
Dev Loss: 0.5903 | Dev F1: 0.4960
LR: 4.44e-06
  No improvement (2/3)

Epoch 9/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Train Loss: 0.4850
Dev Loss: 0.5897 | Dev F1: 0.5079
LR: 2.22e-06
   New best F1: 0.5079

Epoch 10/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7a499d5900>
Traceback (most recent call last):
  File "/home/dzuniga/.conda/envs/multimodal/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/dzuniga/.conda/envs/multimo

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.4803
Dev Loss: 0.5895 | Dev F1: 0.5079
LR: 0.00e+00
  No improvement (1/3)

  Training complete! Best F1: 0.5079


Testing:   0%|          | 0/19 [00:00<?, ?it/s]




 ViT - STANCE TEST RESULTS
F1 Score:  0.5191
Accuracy:  0.6233


IndexError: index 0 is out of bounds for axis 0 with size 0

In [22]:
# We initialize ViT for persuasiveness
vit_pers_model = ViTClassifier(VIT, num_classes=2, freeze_encoder=True)

vit_pers_model, vit_pers_history = train_advanced_model(
    model=vit_pers_model,
    processor=vit_processor,
    train_df=df_train,
    dev_df=df_dev,
    target_task='persuasiveness',
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    device=device)

vit_pers_results = evaluate_advanced_model(vit_pers_model, vit_processor, df_test, 'persuasiveness', device)


print("\n\n\n ViT - PERSUASIVENESS TEST RESULTS")
print(f"F1 Score:  {vit_pers_results['f1']:.4f}")
print(f"Accuracy:  {vit_pers_results['accuracy']:.4f}")
print(f"Baseline (ResNet50): {baseline_results[baseline_results['Task']=='Persuasiveness']['f1'].values[0]:.4f}")
print(f"Improvement: {(vit_pers_results['f1'] - baseline_results[baseline_results['Task']=='Persuasiveness']['f1'].values[0]):.4f}")

torch.save(vit_stance_model.state_dict(), os.path.join(OUTPUT_DIR, 'vit_pers_best.pth'))

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViT encoder partially frozen (only training last layer and classifier head)
ViT model initialized: google/vit-base-patch16-224
- Num classes: 2
Training for: PERSUASIVENESS
  Train batches: 114
  Dev batches: 13

Epoch 1/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.7104
Dev Loss: 0.6281 | Dev F1: 0.2222
LR: 2.00e-05
   New best F1: 0.2222

Epoch 2/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.5904
Dev Loss: 0.6043 | Dev F1: 0.1750
LR: 1.78e-05
  No improvement (1/3)

Epoch 3/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.5536
Dev Loss: 0.5893 | Dev F1: 0.1081
LR: 1.56e-05
  No improvement (2/3)

Epoch 4/10
----------------------------------------


Training:   0%|          | 0/114 [00:00<?, ?it/s]

Validation:   0%|          | 0/13 [00:00<?, ?it/s]

Train Loss: 0.5284
Dev Loss: 0.5853 | Dev F1: 0.2195
LR: 1.33e-05
  No improvement (3/3)

  Early stopping at epoch 4

  Training complete! Best F1: 0.2222


Testing:   0%|          | 0/19 [00:00<?, ?it/s]




 ViT - PERSUASIVENESS TEST RESULTS
F1 Score:  0.1240
Accuracy:  0.6233
Baseline (ResNet50): 0.4114
Improvement: -0.2874


### We compare our models

In [None]:
# Create comparison dataframe
comparison_df = pd.DataFrame({
    'Model': ['ResNet50', 'CLIP', 'ViT'] * 2,
    'Task': ['Stance']*3 + ['Persuasiveness']*3,
    'F1 Score': [
        baseline_results[baseline_results['Task']=='Stance']['f1'].values[0],
        clip_stance_results['f1'],
        vit_stance_results['f1'],
        baseline_results[baseline_results['Task']=='Persuasiveness']['f1'].values[0],
        clip_pers_results['f1'],
        vit_pers_results['f1']
    ],
    "Accuracy": [
        baseline_results[baseline_results['Task']=='Stance']['accuracy'].values[0],
        clip_stance_results['accuracy'],
        vit_stance_results['accuracy'],
        baseline_results[baseline_results['Task']=='Persuasiveness']['accuracy'].values[0],
        clip_pers_results['accuracy'],
        vit_pers_results['accuracy']
    ],
    "Recall": [
        baseline_results[baseline_results['Task']=='Stance']['recall'].values[0],
        clip_stance_results['recall'],
        vit_stance_results['recall'],
        baseline_results[baseline_results['Task']=='Persuasiveness']['recall'].values[0],
        clip_pers_results['recall'],
        vit_pers_results['recall']
    ],
    "Precision": [
        baseline_results[baseline_results['Task']=='Stance']['precision'].values[0],
        clip_stance_results['precision'],
        vit_stance_results['precision'],
        baseline_results[baseline_results['Task']=='Persuasiveness']['precision'].values[0],
        clip_pers_results['precision'],
        vit_pers_results['precision']
    ]})

#Export Results
OUTPUT_DIR = "../../results/vision/advanced_vision_models/"
os.makedirs(OUTPUT_DIR, exist_ok=True)
comparison_df.to_csv(os.path.join(OUTPUT_DIR, 'advanced_vision_model_comparison.csv'), index=False)
print("Results exported to {OUTPUT_DIR}")