In [None]:
# ================================================
# ✅ MULTIMODAL FUSION MODEL FOR BEST F1 SCORE
# ================================================



import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, SwinForImageClassification
from torch.optim import AdamW
import torchvision.transforms as T
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
import torch.nn as nn
import torch.nn.functional as F
import re
import string
import json



# ================================================
# ✅ PATHS & SETUP
# ================================================
# Kaggle dataset locations: image folder and the CSV index of captions/labels.
image_dir = "/kaggle/input/basem/images"
input_csv = "/kaggle/input/basem/dataset.csv"

# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

# ================================================
# ✅ LOAD & PREPROCESS CSV
# ================================================
df = pd.read_csv(input_csv)

# Keep only rows whose image file exists under image_dir. The 'label 2' value is
# shifted down by one (presumably 1-based -> 0-based class ids; confirm against the CSV).
existing_data = [
    {
        'Image_path': image_path,
        'Captions': caption,
        'Label_Sentiment': raw_label - 1,
    }
    for image_path, caption, raw_label in zip(
        (os.path.join(image_dir, file_name) for file_name in df['image_path']),
        df['extracted_text'],
        df['label 2'],
    )
    if os.path.exists(image_path)
]

processed_df = pd.DataFrame(existing_data)

# ================================================
# ✅ TEXT CLEANING
# ================================================
import unicodedata
from bnlp.corpus import stopwords

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# IndicTrans2 English->Indic translation checkpoint, used by translate_en_to_bn below.
# trust_remote_code is passed so the checkpoint's custom tokenizer/model code can be loaded.
model_name = "ai4bharat/indictrans2-en-indic-1B"
_indic_kwargs = dict(trust_remote_code=True)
tokenizer_indic = AutoTokenizer.from_pretrained(model_name, **_indic_kwargs)
model_indic = AutoModelForSeq2SeqLM.from_pretrained(model_name, **_indic_kwargs)
def translate_en_to_bn(text):
    """Translate English-containing text to Bangla with IndicTrans2.

    Translation is attempted only when ``text`` is a string containing at least one
    ASCII letter. Anything else (pure Bangla/numeric strings, non-strings) is returned
    unchanged.

    Args:
        text: caption text, typically already stripped of URLs/HTML/punctuation.

    Returns:
        The decoded translation, or ``text`` itself when no translation was attempted.
    """
    import contextlib

    if not isinstance(text, str) or not re.search(r'[a-zA-Z]', text):
        return text

    # IndicTrans2's remote-code tokenizer expects "<src_tag> <tgt_tag> <sentence>" with
    # FLORES-200 style tags. "<2bn>" (an mBART/IndicBART convention) is not one of them.
    input_text = f"eng_Latn ben_Beng {text}"
    inputs = tokenizer_indic(input_text, return_tensors="pt", truncation=True, max_length=256)
    # Send the inputs wherever the model currently lives (it is not moved in this file).
    inputs = {name: tensor.to(model_indic.device) for name, tensor in inputs.items()}
    with torch.no_grad():
        output = model_indic.generate(**inputs, max_length=256)

    # Generated ids come from the target-side vocabulary, so decode in target mode when
    # the tokenizer supports switching.
    decode_ctx = (
        tokenizer_indic.as_target_tokenizer()
        if hasattr(tokenizer_indic, "as_target_tokenizer")
        else contextlib.nullcontext()
    )
    with decode_ctx:
        translated = tokenizer_indic.decode(output[0], skip_special_tokens=True)
    # NOTE(review): the reference IndicTrans2 pipeline also applies IndicProcessor
    # (IndicTransToolkit) pre/post-processing. Without it, output quality/script may
    # differ; confirm the results survive the Bangla-only character filter in clean_text.
    return translated

def normalize_bangla(text):
    """Return ``text`` in Unicode NFC (canonical composition) form.

    This is the hook for Bangla-specific normalization; only NFC is applied today.
    """
    normalized = unicodedata.normalize("NFC", text)
    return normalized

def remove_bangla_stopwords(text):
    """Drop whitespace-separated tokens that appear in bnlp's Bangla ``stopwords``.

    Surviving tokens are re-joined with single spaces, so whitespace runs collapse too.
    """
    return ' '.join(filter(lambda token: token not in stopwords, text.split()))

def clean_text(text):
    """Clean a raw caption into Bangla-only text for the text encoder.

    Pipeline: remove URLs and HTML tags, strip ASCII punctuation, collapse whitespace,
    translate English-containing text to Bangla, keep only Bangla-block characters,
    ASCII digits and spaces, NFC-normalize, drop Bangla stopwords, collapse whitespace.

    Args:
        text: the caption. Missing values (None/NaN) yield "". Other non-string
            values (e.g. a numeric caption read from the CSV) are converted with str().

    Returns:
        The cleaned string (possibly empty).
    """
    if pd.isna(text):
        return ""
    # re.sub / str.translate below require a str; numeric cells would otherwise raise.
    text = str(text)
    # Remove URLs and HTML tags
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'<.*?>', '', text)
    # Remove ASCII punctuation
    text = text.translate(str.maketrans('', '', string.punctuation))
    # Collapse whitespace
    text = " ".join(text.split())
    # Translate English to Bangla
    text = translate_en_to_bn(text)
    # Keep only the Bangla Unicode block (U+0980-U+09FF), ASCII digits and spaces
    text = re.sub(r'[^\u0980-\u09FF0-9 ]+', '', text)
    text = normalize_bangla(text)
    text = remove_bangla_stopwords(text)
    # Collapse whitespace again after the removals above
    text = " ".join(text.split())
    return text

# ================================================
# ✅ DATA SPLITS
# ================================================
# 70% train. The remaining 30% is split 2:1 into test (20%) and validation (10%),
# stratified on the sentiment label at both stages.
train_df, temp_df = train_test_split(processed_df, test_size=0.3, stratify=processed_df['Label_Sentiment'], random_state=42)
test_df, val_df = train_test_split(temp_df, test_size=1/3, stratify=temp_df['Label_Sentiment'], random_state=42)

for split_df in (train_df, test_df, val_df):
    # fillna("") must precede astype(str). Otherwise missing captions become the literal
    # string "nan", which clean_text's pd.isna guard cannot catch and which would be
    # sent through English->Bangla translation.
    split_df['Captions'] = split_df['Captions'].fillna("").astype(str).apply(clean_text)
    split_df['label'] = split_df['Label_Sentiment']

print(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}, Test samples: {len(test_df)}")
print(f"Class distribution: {train_df['label'].value_counts().sort_index().tolist()}")

# ================================================
# ✅ LOAD MODELS
# ================================================
# Text encoder: BanglaBERT ("csebuetnlp/banglabert").
# Note: this is not BanglishBERT, which is a separate checkpoint ("csebuetnlp/banglishbert").
bert_tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/banglabert")
bert_model = AutoModel.from_pretrained("csebuetnlp/banglabert")

# Image encoder: Swin-B (patch 4, window 7, 224 px).
# It is loaded via SwinForImageClassification only so MultimodalFusionModel can take its
# `.swin` backbone and `config.hidden_size`. ignore_mismatched_sizes allows the 3-label head
# to replace the pretrained head, but that head is never called by the fusion model.
swin_model_name = "microsoft/swin-base-patch4-window7-224"
image_processor = AutoImageProcessor.from_pretrained(swin_model_name)
swin_backbone = SwinForImageClassification.from_pretrained(
    swin_model_name,
    num_labels=3,
    ignore_mismatched_sizes=True
)

# ================================================
# ✅ MULTIMODAL FUSION MODEL
# ================================================
class MultimodalFusionModel(nn.Module):
    """Late-fusion classifier over a BERT-style text encoder and a Swin image encoder.

    Text is summarized by the first-token ([CLS]) hidden state and images by mean-pooling
    the backbone's last hidden state over tokens. Each modality is projected to
    ``fusion_dim`` and re-weighted by a learned two-way softmax gate. The gated features are
    concatenated and passed through a small MLP head.

    Args:
        bert_model: text encoder. Must expose ``config.hidden_size`` and, when called with
            ``input_ids``/``attention_mask``, return an object with ``last_hidden_state``.
        swin_model: image model exposing ``.swin`` (a backbone returning an object with
            ``last_hidden_state``) and ``config.hidden_size``. Its own classifier is unused.
        num_classes: number of output logits.
        dropout_rate: dropout on each modality's features and inside the head.
        fusion_dim: width of each projected modality feature.
    """

    def __init__(self, bert_model, swin_model, num_classes=3, dropout_rate=0.3, fusion_dim=512):
        super().__init__()

        # Text encoder
        self.bert = bert_model
        self.text_dropout = nn.Dropout(dropout_rate)
        self.text_projector = nn.Linear(bert_model.config.hidden_size, fusion_dim)

        # Image encoder - use Swin backbone without classifier
        self.swin_backbone = swin_model.swin
        self.image_dropout = nn.Dropout(dropout_rate)
        self.image_projector = nn.Linear(swin_model.config.hidden_size, fusion_dim)

        # Fusion layers
        self.fusion_dropout = nn.Dropout(dropout_rate)
        self.fusion_layer1 = nn.Linear(fusion_dim * 2, fusion_dim)
        self.fusion_layer2 = nn.Linear(fusion_dim, fusion_dim // 2)
        # BatchNorm1d cannot normalize a single sample in training mode; batches of 1 must
        # be avoided during training.
        self.batch_norm = nn.BatchNorm1d(fusion_dim // 2)

        # Classification head
        self.classifier = nn.Linear(fusion_dim // 2, num_classes)

        # Attention gate: scores the two modalities from their concatenated projections
        self.attention_weights = nn.Linear(fusion_dim * 2, 2)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return unnormalized logits of shape ``(batch, num_classes)``."""
        # Text encoding: first-token ([CLS]) representation
        text_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.last_hidden_state[:, 0, :]
        text_projected = self.text_projector(self.text_dropout(text_features))

        # Image encoding: average over the backbone's output tokens
        image_outputs = self.swin_backbone(pixel_values)
        image_features = image_outputs.last_hidden_state.mean(dim=1)
        image_projected = self.image_projector(self.image_dropout(image_features))

        # Per-sample modality weights (softmax over {text, image})
        combined_features = torch.cat([text_projected, image_projected], dim=1)
        attention_scores = F.softmax(self.attention_weights(combined_features), dim=1)
        text_att = attention_scores[:, 0:1]
        image_att = attention_scores[:, 1:2]

        # Feed the attention-gated features to the head. Previously the weighted features
        # were computed and then discarded, so the gate never influenced the logits and its
        # parameters received no gradient.
        gated_features = torch.cat([text_att * text_projected, image_att * image_projected], dim=1)

        fusion_out = F.relu(self.fusion_layer1(gated_features))
        fusion_out = self.fusion_dropout(fusion_out)
        fusion_out = F.relu(self.fusion_layer2(fusion_out))
        fusion_out = self.batch_norm(fusion_out)

        return self.classifier(fusion_out)

# ================================================
# ✅ MULTIMODAL DATASET
# ================================================
class MultimodalDataset(Dataset):
    """Yield tokenized caption + processed image + label triples for the fusion model.

    Args:
        df: DataFrame providing 'Captions', 'Image_path' and 'label' columns.
        tokenizer: Hugging Face tokenizer applied to each caption.
        processor: Hugging Face image processor producing 'pixel_values'.
        max_length: captions are padded/truncated to exactly this many tokens.
        is_train: when True, apply random PIL augmentations before the processor.
    """

    def __init__(self, df, tokenizer, processor, max_length=128, is_train=False):
        self.df = df
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_length = max_length
        self.is_train = is_train
        # Augmentations applied (on PIL images) to training samples only
        self.train_transforms = T.Compose([
            T.RandomRotation(15),
            T.RandomHorizontalFlip(),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomAdjustSharpness(sharpness_factor=2),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict of 1-D input_ids/attention_mask, pixel_values and a long label."""
        row = self.df.iloc[idx]
        text_inputs = self.tokenizer(
            row['Captions'],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        # convert() returns a new in-memory image, so the source file can be closed right
        # away instead of leaving its handle open until garbage collection.
        with Image.open(row['Image_path']) as raw_image:
            image = raw_image.convert('RGB')
        if self.is_train:
            image = self.train_transforms(image)
        image_inputs = self.processor(image, return_tensors="pt")
        return {
            # Drop the leading batch dimension of size 1 added by return_tensors='pt'
            'input_ids': text_inputs['input_ids'].flatten(),
            'attention_mask': text_inputs['attention_mask'].flatten(),
            'pixel_values': image_inputs['pixel_values'].squeeze(0),
            'label': torch.tensor(row['label'], dtype=torch.long)
        }

# ================================================
# ✅ DATALOADERS
# ================================================
batch_size = 8

train_dataset = MultimodalDataset(train_df, bert_tokenizer, image_processor, is_train=True)
val_dataset = MultimodalDataset(val_df, bert_tokenizer, image_processor, is_train=False)
test_dataset = MultimodalDataset(test_df, bert_tokenizer, image_processor, is_train=False)

# The fusion head contains BatchNorm1d, which raises in training mode on a batch holding a
# single sample. Drop the final training batch only when it would contain exactly one
# sample, so no data is discarded otherwise. Eval-mode loaders are unaffected.
drop_last_train = len(train_dataset) % batch_size == 1

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last_train)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# ================================================
# ✅ INITIALIZE MODEL
# ================================================
# Assemble the fusion network from the two pretrained encoders, then move it to `device`.
model = MultimodalFusionModel(
    bert_model,
    swin_backbone,
    num_classes=3,
    dropout_rate=0.3,
)
model = model.to(device)

# ================================================
# ✅ LOSS & OPTIMIZER WITH ADVANCED TECHNIQUES
# ================================================
# Focal Loss for handling class imbalance
class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled by ``(1 - p_true) ** gamma``, optionally class-weighted.

    Args:
        alpha: optional 1-D tensor of per-class weights, indexed by the target class.
        gamma: focusing exponent; ``gamma=0`` without ``alpha`` reduces to cross-entropy.
        reduction: 'mean' or 'sum'; any other value returns the per-sample losses.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) recovers the softmax probability assigned to the true class
        prob_true = per_sample_ce.neg().exp()
        loss = per_sample_ce * (1 - prob_true).pow(self.gamma)

        if self.alpha is not None:
            loss = loss * self.alpha[targets]

        reducers = {'mean': torch.mean, 'sum': torch.sum}
        reduce_fn = reducers.get(self.reduction)
        return loss if reduce_fn is None else reduce_fn(loss)

# Inverse-frequency class weights (total / count per class, classes ordered by label id)
class_counts = train_df['label'].value_counts().sort_index().tolist()
total_samples = sum(class_counts)
class_weights = [total_samples / count for count in class_counts]
alpha = torch.tensor(class_weights, dtype=torch.float32, device=device)

# Use Focal Loss for better handling of class imbalance
criterion = FocalLoss(alpha=alpha, gamma=2.0)

# Discriminative learning rates: small for the pretrained encoders, larger for the new head
text_params = list(model.bert.parameters())
image_params = list(model.swin_backbone.parameters())
# Everything that is not part of a pretrained encoder belongs to the fusion head. Collect it
# by exclusion rather than by listing layers by hand: the hand-written list omitted the
# BatchNorm affine parameters, so they were never updated.
_encoder_param_ids = {id(p) for p in text_params + image_params}
fusion_params = [p for p in model.parameters() if id(p) not in _encoder_param_ids]

optimizer = AdamW([
    {'params': text_params, 'lr': 2e-5},
    {'params': image_params, 'lr': 1e-5},
    {'params': fusion_params, 'lr': 5e-4}
], weight_decay=0.01)

# Cosine decay over the full training run. CosineAnnealingLR keeps following the cosine past
# T_max, so with T_max=20 and 25 epochs the LR would climb back up in the last 5 epochs.
# Keep this equal to num_epochs in the training loop.
scheduler_epochs = 25
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_epochs, eta_min=1e-6)

# ================================================
# ✅ TRAINING LOOP WITH ADVANCED TECHNIQUES
# ================================================
num_epochs = 25
patience = 5
patience_counter = 0
# Start below any attainable F1 so the first epoch always writes a checkpoint. With 0.0, a
# run whose validation F1 never exceeded 0 saved nothing, and the final torch.load failed.
best_val_f1 = float('-inf')

print("🚀 Starting Multimodal Fusion Training...")

for epoch in range(num_epochs):
    # ============================================================
    # TRAINING PHASE
    # ============================================================
    model.train()
    total_train_loss = 0
    train_predictions = []
    train_labels = []

    for batch in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        logits = model(input_ids, attention_mask, pixel_values)
        loss = criterion(logits, labels)

        loss.backward()
        # Gradient clipping: rescale so the global gradient L2 norm is at most 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_train_loss += loss.item()

        predictions = torch.argmax(logits, dim=1)
        train_predictions.extend(predictions.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())

    avg_train_loss = total_train_loss / len(train_loader)
    train_accuracy = accuracy_score(train_labels, train_predictions)
    # zero_division=0 is the value sklearn falls back to anyway, minus the warnings
    train_f1 = precision_recall_fscore_support(
        train_labels, train_predictions, average='weighted', zero_division=0
    )[2]

    # ============================================================
    # VALIDATION PHASE
    # ============================================================
    model.eval()
    total_val_loss = 0
    val_predictions = []
    val_labels = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['label'].to(device)

            logits = model(input_ids, attention_mask, pixel_values)
            loss = criterion(logits, labels)

            total_val_loss += loss.item()

            predictions = torch.argmax(logits, dim=1)
            val_predictions.extend(predictions.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = accuracy_score(val_labels, val_predictions)
    val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(
        val_labels, val_predictions, average='weighted', zero_division=0
    )

    # Step scheduler once per epoch (the LR printed below is therefore next epoch's)
    scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.4f} | Train F1: {train_f1:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.4f} | Val F1: {val_f1:.4f}")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")

    # ============================================================
    # EARLY STOPPING BASED ON F1 SCORE
    # ============================================================
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        patience_counter = 0
        torch.save(model.state_dict(), "best_multimodal_model.pt")
        print(f"✅ Validation F1 improved to {val_f1:.4f} — model saved.")
    else:
        patience_counter += 1
        print(f"⏰ No improvement — patience {patience_counter}/{patience}")

        if patience_counter >= patience:
            print(f"🛑 Early stopping triggered at epoch {epoch+1}")
            break
    print("-" * 70)

# ================================================
# ✅ FINAL TEST EVALUATION
# ================================================
print("\n🔍 Loading best model for final evaluation...")
# map_location lets the checkpoint be restored onto `device` even in a session whose
# device differs from the one it was saved from.
model.load_state_dict(torch.load("best_multimodal_model.pt", map_location=device))
model.eval()

test_predictions = []
test_labels = []
total_test_loss = 0

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Final Test Evaluation"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['label'].to(device)

        logits = model(input_ids, attention_mask, pixel_values)
        loss = criterion(logits, labels)

        total_test_loss += loss.item()
        predictions = torch.argmax(logits, dim=1)
        test_predictions.extend(predictions.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

# Pin every metric to the full label set. Without `labels=`, a class absent from both the
# labels and predictions shrinks the per-class arrays and the confusion matrix, so the [i]
# indexing below fails and classification_report rejects the 3 target names.
class_names = ['Negative', 'Neutral', 'Positive']
label_ids = list(range(len(class_names)))

# Calculate comprehensive metrics
test_accuracy = accuracy_score(test_labels, test_predictions)
test_precision, test_recall, test_f1, _ = precision_recall_fscore_support(
    test_labels, test_predictions, labels=label_ids, average='weighted', zero_division=0
)
test_precision_macro, test_recall_macro, test_f1_macro, _ = precision_recall_fscore_support(
    test_labels, test_predictions, labels=label_ids, average='macro', zero_division=0
)
cm = confusion_matrix(test_labels, test_predictions, labels=label_ids)

# Per-class metrics
precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support(
    test_labels, test_predictions, labels=label_ids, average=None, zero_division=0
)

print("\n" + "="*70)
print("🎯 FINAL MULTIMODAL FUSION TEST RESULTS")
print("="*70)
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test F1-Score (Weighted): {test_f1:.4f}")
print(f"Test F1-Score (Macro): {test_f1_macro:.4f}")
print(f"Test Precision (Weighted): {test_precision:.4f}")
print(f"Test Recall (Weighted): {test_recall:.4f}")
print(f"Test Loss: {total_test_loss/len(test_loader):.4f}")

print("\n📈 Per-Class Metrics:")
for i, class_name in enumerate(class_names):
    print(f"{class_name:>8}: Precision={precision_per_class[i]:.4f}, Recall={recall_per_class[i]:.4f}, F1={f1_per_class[i]:.4f}, Support={support[i]}")

print(f"\n🎯 Confusion Matrix:")
print(f"{'':>10} {'Neg':>6} {'Neu':>6} {'Pos':>6}")
for i, class_name in enumerate(class_names):
    print(f"{class_name:>10} {cm[i][0]:>6} {cm[i][1]:>6} {cm[i][2]:>6}")

print("\n📋 Detailed Classification Report:")
print(classification_report(test_labels, test_predictions, labels=label_ids, target_names=class_names, zero_division=0))

# ================================================
# ✅ SAVE RESULTS
# ================================================
results = {
    'test_accuracy': test_accuracy,
    'test_f1_weighted': test_f1,
    'test_f1_macro': test_f1_macro,
    'test_precision_weighted': test_precision,
    'test_recall_weighted': test_recall,
    'test_loss': total_test_loss/len(test_loader),
    # Validation F1 of the checkpoint that was selected by early stopping
    'best_val_f1_weighted': float(best_val_f1),
    'confusion_matrix': cm.tolist(),
    'per_class_metrics': {
        'precision': precision_per_class.tolist(),
        'recall': recall_per_class.tolist(),
        'f1': f1_per_class.tolist(),
        'support': support.tolist()
    }
}

with open('/kaggle/working/multimodal_fusion_results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

print("\n" + "="*70)
print("✅ MULTIMODAL FUSION MODEL TRAINING COMPLETE!")
# test_f1 is the held-out test score. The "best" F1 is the validation score used to pick
# the checkpoint, so report the two separately instead of labelling test F1 as "best".
print(f"🏆 Test F1 (Weighted): {test_f1:.4f} | Best Val F1 (Weighted): {best_val_f1:.4f}")
print("📁 Results saved to 'multimodal_fusion_results.json'")
print("="*70)