In [None]:
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.optim import AdamW
import torch.nn as nn
from tqdm import tqdm

from DL_vs_HateSpeech.models.model_v0 import ModelV0
from DL_vs_HateSpeech.loading_data.dataloader import DataLoader


# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Hyperparameters ---
BATCH_SIZE = 32  # samples per batch for both the train and val loaders
LR = 1e-5        # AdamW learning rate
EPOCHS = 10      # number of full passes over the training split

def collate_fn(batch):
    """Collate a list of ``(image, text, label)`` samples into a single batch.

    Images (PIL images) and texts (strings) cannot be stacked into tensors,
    so they are handed back as plain Python lists in batch order. Only the
    labels are converted to a tensor.

    Args:
        batch: Sequence of ``(image, text, label)`` tuples from the dataset.

    Returns:
        ``(images, texts, labels)`` where ``images`` and ``texts`` are lists
        and ``labels`` is a ``float32`` tensor (1-D when each label is a
        scalar).
    """
    image_col, text_col, label_col = zip(*batch)  # transpose samples -> columns
    label_tensor = torch.tensor(label_col, dtype=torch.float32)
    return [*image_col], [*text_col], label_tensor

# --- Initialize Datasets and DataLoaders ---
# The project's dataset class is also called ``DataLoader``; presumably that
# is why torch's loader is imported under the TorchDataLoader alias.
train_dataset = DataLoader(type="train")
val_dataset = DataLoader(type="val")

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = TorchDataLoader(
    dataset=train_dataset,
    collate_fn=collate_fn,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_loader = TorchDataLoader(
    dataset=val_dataset,
    collate_fn=collate_fn,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# --- Initialize Model ---
model = ModelV0(clip_model_type="32")
model = model.to(device)
optimizer = AdamW(params=model.parameters(), lr=LR)
# BCELoss consumes probabilities in [0, 1]; ModelV0 presumably ends in a
# sigmoid — TODO confirm. reduction="mean" is the default, spelled out here.
criterion = nn.BCELoss(reduction="mean")  # Binary Cross-Entropy Loss

# --- Training Function ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Args:
        model: Module called as ``model(texts, images)``; its output must
            match ``labels`` in shape for the BCE-style ``criterion``.
        dataloader: Iterable of ``(images, texts, labels)`` batches, e.g. as
            produced by ``collate_fn``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss callable ``criterion(probs, labels)`` returning a
            scalar tensor. Assumed to use ``reduction="mean"`` (the
            ``nn.BCELoss`` default), i.e. a per-batch average.
        device: Device the labels are moved to before the loss is computed.

    Returns:
        Sample-weighted mean training loss over the epoch, or ``nan`` if the
        dataloader yielded no samples.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0

    for images, texts, labels in tqdm(dataloader, desc="Training"):
        # Images and texts stay as Python lists (not tensors); presumably the
        # model preprocesses them and places them on its device — TODO confirm.
        labels = labels.float().to(device)

        # Forward pass
        optimizer.zero_grad()
        probs = model(texts, images)
        loss = criterion(probs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # The loss is a per-batch mean, so weight it by batch size: a plain
        # average of batch means over-weights a smaller final batch.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    return total_loss / n_samples if n_samples else float("nan")

# --- Validation Function ---
def evaluate(model, dataloader, device, threshold=0.5):
    """Compute binary-classification accuracy of ``model`` on ``dataloader``.

    Runs without gradient tracking and with the model in eval mode.

    Args:
        model: Module called as ``model(texts, images)`` that returns one
            probability per sample.
        dataloader: Iterable of ``(images, texts, labels)`` batches, e.g. as
            produced by ``collate_fn``.
        device: Device the labels are moved to before comparison.
        threshold: A sample is predicted positive when its probability is
            strictly greater than this value (default ``0.5``).

    Returns:
        Fraction of samples whose thresholded prediction equals the label, or
        ``nan`` if the dataloader yielded no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, texts, labels in tqdm(dataloader, desc="Evaluating"):
            # Images and texts stay as Python lists; only labels are tensors.
            labels = labels.to(device)

            # Forward pass
            probs = model(texts, images)
            # Align shapes first: a (B, 1) output compared against (B,) labels
            # would silently broadcast to (B, B) and corrupt the count.
            preds = (probs.reshape(labels.shape) > threshold).to(labels.dtype)

            # Calculate accuracy
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total if total else float("nan")

# --- Training Loop ---
# NOTE(review): in the recorded output, validation accuracy peaks at epoch 4
# (65.10%) while training loss keeps falling to ~0.01, i.e. the model
# overfits. Consider early stopping or keeping the best-epoch weights.
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")

    # One optimisation pass over the training split.
    train_loss = train_epoch(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
    )
    print(f"Train Loss: {train_loss:.4f}")

    # Accuracy on the held-out validation split ('%' format scales by 100).
    val_accuracy = evaluate(model=model, dataloader=val_loader, device=device)
    print(f"Val Accuracy: {val_accuracy:.2%}")

  from .autonotebook import tqdm as notebook_tqdm



Epoch 1/10


Training: 100%|██████████| 14/14 [00:47<00:00,  3.41s/it]


Train Loss: 0.6838


Evaluating: 100%|██████████| 5/5 [00:04<00:00,  1.10it/s]


Val Accuracy: 61.07%

Epoch 2/10


Training: 100%|██████████| 14/14 [00:49<00:00,  3.56s/it]


Train Loss: 0.6206


Evaluating: 100%|██████████| 5/5 [00:04<00:00,  1.07it/s]


Val Accuracy: 64.43%

Epoch 3/10


Training: 100%|██████████| 14/14 [00:53<00:00,  3.79s/it]


Train Loss: 0.5452


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.04s/it]


Val Accuracy: 64.43%

Epoch 4/10


Training: 100%|██████████| 14/14 [00:52<00:00,  3.76s/it]


Train Loss: 0.4145


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.07s/it]


Val Accuracy: 65.10%

Epoch 5/10


Training: 100%|██████████| 14/14 [00:54<00:00,  3.87s/it]


Train Loss: 0.2229


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.07s/it]


Val Accuracy: 62.42%

Epoch 6/10


Training: 100%|██████████| 14/14 [00:53<00:00,  3.80s/it]


Train Loss: 0.0916


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.04s/it]


Val Accuracy: 57.05%

Epoch 7/10


Training: 100%|██████████| 14/14 [00:54<00:00,  3.91s/it]


Train Loss: 0.0530


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.03s/it]


Val Accuracy: 63.09%

Epoch 8/10


Training: 100%|██████████| 14/14 [00:55<00:00,  3.94s/it]


Train Loss: 0.0267


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.06s/it]


Val Accuracy: 64.43%

Epoch 9/10


Training: 100%|██████████| 14/14 [00:52<00:00,  3.79s/it]


Train Loss: 0.0159


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.10s/it]


Val Accuracy: 63.09%

Epoch 10/10


Training: 100%|██████████| 14/14 [00:54<00:00,  3.89s/it]


Train Loss: 0.0110


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.08s/it]


Val Accuracy: 63.76%
