# 💊 Demo: Hệ thống Nhận dạng Viên Thuốc Multimodal với Transformer

## Tổng quan
Notebook này minh họa quy trình hoàn chỉnh để xây dựng và đánh giá hệ thống nhận dạng viên thuốc sử dụng:

- **Multimodal Transformer**: Kết hợp thông tin từ hình ảnh và text imprint
- **Cross-modal Attention**: Học representation chung cho visual và textual features  
- **Apache Spark**: Xử lý dữ liệu lớn
- **GPU Acceleration**: RAPIDS cuDF/CuPy để tăng tốc (tùy chọn; tự động dùng CPU nếu chưa cài đặt)

## Mục tiêu
1. Xử lý dữ liệu multimodal (hình ảnh + text)
2. Xây dựng và training Multimodal Transformer
3. Đánh giá hiệu suất model
4. Demo inference trên tập test (dữ liệu tổng hợp)

## 1. Import Required Libraries
Import các thư viện cần thiết cho xử lý dữ liệu, machine learning và visualization.

In [None]:
# Core libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2
import warnings
warnings.filterwarnings('ignore')

# Deep Learning
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from transformers import AutoModel, AutoTokenizer
import timm

# Spark và Big Data
try:
    import findspark
    findspark.init()
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import StringType, BinaryType, ArrayType, FloatType
    SPARK_AVAILABLE = True
    print("✅ Apache Spark available")
except ImportError:
    SPARK_AVAILABLE = False
    print("⚠️ Apache Spark not available")

# Rapids for GPU acceleration
try:
    import cudf
    import cupy as cp
    RAPIDS_AVAILABLE = True
    print("✅ Rapids CUDF/CuPy available")
except ImportError:
    RAPIDS_AVAILABLE = False
    print("⚠️ Rapids not available, using CPU")

# Plotting
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Custom modules
sys.path.append('../src')
from models.multimodal_transformer import MultimodalPillTransformer
from data.data_processing import SparkDataProcessor, PillDataset, get_data_transforms
from utils.metrics import MetricsCalculator
from utils.utils import set_seed, get_device

# Global plotting theme: seaborn-flavoured matplotlib style with the "husl" palette.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Environment summary (the actual torch device is chosen later via get_device()).
device_label = 'GPU' if torch.cuda.is_available() else 'CPU'
print("📦 All libraries imported successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🖥️ Device: {device_label}")

## 2. Load and Inspect Data
Thiết lập cấu hình và tạo dữ liệu mẫu tổng hợp (ảnh ngẫu nhiên + text imprint) để demo hệ thống.

In [None]:
# Reproducibility and device selection (project helpers from utils.utils).
set_seed(42)
device = get_device()

# Model architecture: "vit" image encoder + "bert" text encoder, fused with
# cross-attention, followed by a classifier head.
model_config = {
    "visual_encoder": {
        "type": "vit",
        "model_name": "vit_base_patch16_224",
        "pretrained": True,
        "output_dim": 768,
    },
    "text_encoder": {
        "type": "bert",
        "model_name": "bert-base-uncased",
        "pretrained": True,
        "output_dim": 768,
        "max_length": 128,
    },
    "fusion": {
        "type": "cross_attention",
        "hidden_dim": 512,
        "num_attention_heads": 8,
        "dropout": 0.1,
    },
    # NOTE(review): the synthetic dataset generated below only uses 12 class
    # ids (0-11); the remaining outputs of this 100-way head are never targets.
    "classifier": {
        "num_classes": 100,
        "hidden_dims": [512, 256],
        "dropout": 0.3,
    },
}

# Data settings, including a small local Spark session configuration.
data_config = {
    "image_size": 224,
    "spark": {
        "app_name": "PillRecognitionDemo",
        "master": "local[2]",
        "executor_memory": "2g",
        "driver_memory": "1g",
    },
}

config = {"model": model_config, "data": data_config}

print("⚙️ Configuration loaded:")
for summary_label, summary_value in (
    ("Visual Encoder", model_config["visual_encoder"]["model_name"]),
    ("Text Encoder", model_config["text_encoder"]["model_name"]),
    ("Fusion Type", model_config["fusion"]["type"]),
    ("Number of Classes", model_config["classifier"]["num_classes"]),
):
    print(f"  - {summary_label}: {summary_value}")

In [None]:
# Tạo dữ liệu mẫu cho demo
def create_sample_data(num_samples=500):
    """Tạo dữ liệu mẫu với hình ảnh và text imprint"""
    
    # Danh sách các loại thuốc mẫu
    pill_classes = [
        "Acetaminophen 500mg", "Ibuprofen 200mg", "Aspirin 325mg",
        "Metformin 500mg", "Lisinopril 10mg", "Atorvastatin 20mg",
        "Amlodipine 5mg", "Omeprazole 20mg", "Levothyroxine 50mcg",
        "Simvastatin 40mg", "Losartan 50mg", "Gabapentin 300mg"
    ]
    
    data = []
    
    for i in range(num_samples):
        # Random pill class
        pill_class = pill_classes[i % len(pill_classes)]
        class_id = i % len(pill_classes)
        
        # Generate synthetic image (224x224x3)
        np.random.seed(i)
        image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
        
        # Add some pattern to make it look more realistic
        center = (112, 112)
        radius = np.random.randint(40, 80)
        color = tuple(np.random.randint(100, 255, 3).tolist())
        cv2.circle(image, center, radius, color, -1)
        
        # Generate text imprint patterns
        text_patterns = [
            f"PILL{i:03d}",
            f"{class_id}MG",
            f"RX{i:02d}",
            f"TAB{class_id}",
            f"MED{i%100:02d}"
        ]
        text_imprint = text_patterns[i % len(text_patterns)]
        
        data.append({
            "image_id": f"img_{i:05d}",
            "image_array": image,
            "text_imprint": text_imprint,
            "pill_class": pill_class,
            "class_id": class_id,
            "manufacturer": f"Company_{chr(65 + i%5)}",  # Company_A, B, C, D, E
            "dosage": pill_class.split()[-1] if "mg" in pill_class else "Unknown"
        })
    
    return pd.DataFrame(data)

# Build the demo dataset of 500 synthetic samples.
print("🔄 Generating sample dataset...")
sample_df = create_sample_data(500)

print(f"✅ Created dataset with {len(sample_df)} samples")
print(f"📊 Number of unique classes: {sample_df['class_id'].nunique()}")
print(f"📝 Sample text imprints: {sample_df['text_imprint'].unique()[:10]}")

# Basic dataset overview.
print("\n📋 Dataset Info:")
# DataFrame.info() writes its summary itself and returns None; wrapping it in
# print() emitted a stray "None" line.
sample_df.info()
print("\n🔍 First 5 rows:")
sample_df[['image_id', 'text_imprint', 'pill_class', 'manufacturer']].head()

## 3. Data Preprocessing
Khám phá dữ liệu trước khi huấn luyện: xem ảnh mẫu, phân bố các lớp thuốc và độ dài text imprint.

In [None]:
# Show the first 10 samples in a 2x5 grid, titled with imprint and class name.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
fig.suptitle("📸 Sample Pill Images", fontsize=16, fontweight='bold')

# Walk the flattened axes instead of computing (row, col) by hand. This avoids
# rebinding the global `col` (pyspark's `col` from the import cell), and zip()
# stops cleanly if the dataset has fewer than 10 rows.
for ax, (_, sample_row) in zip(axes.flat, sample_df.head(10).iterrows()):
    class_label = sample_row['pill_class']
    # Only append an ellipsis when the name was actually truncated.
    shown_label = class_label if len(class_label) <= 15 else f"{class_label[:15]}..."

    ax.imshow(sample_row['image_array'])
    ax.set_title(f"{sample_row['text_imprint']}\n{shown_label}", fontsize=8)
    ax.axis('off')

plt.tight_layout()
plt.show()

# Class balance: number of samples per pill class.
print("\n📊 Class Distribution:")
class_counts = sample_df['pill_class'].value_counts()
print(class_counts)

# Bar chart of the class counts.
fig = px.bar(
    x=class_counts.index,
    y=class_counts.values,
    title="Phân bố các lớp thuốc trong dataset",
    labels={'x': 'Loại thuốc', 'y': 'Số lượng mẫu'}
)
# Plotly's method is `update_xaxes` (plural); `update_xaxis` does not exist
# and raised AttributeError.
fig.update_xaxes(tickangle=45)
fig.show()

# Imprint length statistics (in characters) and their histogram.
print("\n📝 Text Imprint Analysis:")
text_lengths = sample_df['text_imprint'].str.len()
shortest, longest, average = text_lengths.min(), text_lengths.max(), text_lengths.mean()
print(f"Text length - Min: {shortest}, Max: {longest}, Mean: {average:.2f}")

length_axis_labels = {'x': 'Độ dài text', 'y': 'Số lượng'}
fig = px.histogram(x=text_lengths, title="Phân bố độ dài Text Imprint", labels=length_axis_labels)
fig.show()

## 4. Feature Engineering
Định nghĩa các phép biến đổi ảnh (augmentation/normalization), chia dữ liệu train/validation/test và tạo PyTorch Dataset.

In [None]:
# Image transforms. Inputs here are HxWx3 uint8 numpy arrays (the
# `image_array` column), converted to PIL before the torchvision ops.
# The target size is read from config so it stays in sync with data.image_size.
image_size = config["data"]["image_size"]

# ImageNet channel statistics, the usual normalisation for ImageNet-pretrained
# backbones (TODO confirm against the configured timm ViT's default config).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: resize plus light geometric and colour augmentation.
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((image_size, image_size)),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    # NOTE(review): a horizontal flip mirrors any imprint printed on the pill,
    # while the paired text imprint stays unmirrored; reconsider for real data.
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation/test: deterministic resize and normalisation only.
val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Stratified 70/15/15 split: hold out 30%, then halve the hold-out into
# validation and test, keeping the class balance in every split.
from sklearn.model_selection import train_test_split

train_df, temp_df = train_test_split(
    sample_df, test_size=0.3, random_state=42, stratify=sample_df['class_id']
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.5, random_state=42, stratify=temp_df['class_id']
)

print("📊 Data Split:")
for split_name, split_frame in (("Training", train_df), ("Validation", val_df), ("Test", test_df)):
    print(f"  - {split_name}: {len(split_frame)} samples")

# Map-style dataset over the demo DataFrame.
class DemoPillDataset(torch.utils.data.Dataset):
    """Serve image / imprint-text / label samples from a pandas DataFrame.

    Reads the columns ``image_array``, ``text_imprint``, ``class_id`` and
    ``pill_class`` of each row.
    """

    def __init__(self, dataframe, transform=None):
        # Reset to a 0..n-1 index so positional iloc lookups are valid.
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        pixels = record['image_array']

        if not self.transform:
            # No pipeline: HWC array -> CHW float tensor scaled by 1/255.
            pixels = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
        else:
            pixels = self.transform(pixels)

        return dict(
            image=pixels,
            text=record['text_imprint'],
            label=record['class_id'],
            class_name=record['pill_class'],
        )

# Wrap each split; validation and test share the deterministic pipeline.
train_dataset = DemoPillDataset(train_df, train_transform)
val_dataset, test_dataset = (DemoPillDataset(frame, val_transform) for frame in (val_df, test_df))

print("✅ Datasets created successfully!")

# Inspect one transformed training sample.
sample_batch = train_dataset[0]
sample_shape = sample_batch['image'].shape
sample_text = sample_batch['text']
print("\n🔍 Sample after transformation:")
print(f"  - Image shape: {sample_shape}")
print(f"  - Text: '{sample_text}'")
print(f"  - Label: {sample_batch['label']}")
print(f"  - Class: {sample_batch['class_name']}")

## 5. Model Training
Khởi tạo và training Multimodal Transformer model.

In [None]:
# Build the multimodal model from the "model" section of the config.
print("🤖 Initializing Multimodal Transformer...")
model = MultimodalPillTransformer(config["model"])
model = model.to(device)

# Count all parameters and the subset that will receive gradients, in one pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count

print("📊 Model Statistics:")
print(f"  - Total parameters: {total_params:,}")
print(f"  - Trainable parameters: {trainable_params:,}")

# Batch assembly for the DataLoaders.
def collate_fn(batch):
    """Merge a list of dataset items into one multimodal batch dict.

    Images are stacked into a single tensor, labels become a ``torch.long``
    tensor, and texts/class names stay as Python lists (tokenization happens
    later in the train/eval loops).
    """
    return {
        'images': torch.stack([sample['image'] for sample in batch]),
        'texts': [sample['text'] for sample in batch],
        'labels': torch.tensor([sample['label'] for sample in batch], dtype=torch.long),
        'class_names': [sample['class_name'] for sample in batch],
    }

# Batched loaders; only the training split is shuffled.
batch_size = 16
shared_loader_args = {"batch_size": batch_size, "collate_fn": collate_fn}
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **shared_loader_args)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **shared_loader_args)

# AdamW optimizer and cross-entropy loss over the class logits.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

print("✅ Training setup complete!")
for setup_label, setup_value in (
    ("Batch size", batch_size),
    ("Train batches", len(train_loader)),
    ("Val batches", len(val_loader)),
):
    print(f"  - {setup_label}: {setup_value}")

# Training function (simplified for the demo).
def train_one_epoch(model, dataloader, optimizer, criterion, device, max_length=128):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: Multimodal model exposing ``get_text_tokenizer()`` and callable
            as ``model(images, text_inputs)``, returning a dict with a
            ``"logits"`` tensor.
        dataloader: Iterable of batches shaped like ``collate_fn`` output
            (keys ``images``, ``texts``, ``labels``).
        optimizer: Optimizer over the model's parameters.
        criterion: Loss called as ``criterion(logits, labels)``; assumed to
            return the batch mean (e.g. default ``nn.CrossEntropyLoss``).
        device: Device the tensors are moved to.
        max_length: Maximum token length passed to the tokenizer.

    Returns:
        Per-sample mean loss and accuracy in percent. Both are ``nan`` when
        the dataloader yields no samples (previously a ZeroDivisionError).
    """
    model.train()
    tokenizer = model.get_text_tokenizer()

    loss_sum = 0.0
    correct = 0
    total = 0

    for batch in dataloader:
        images = batch['images'].to(device)
        labels = batch['labels'].to(device)
        text_inputs = tokenizer(
            batch['texts'],
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        optimizer.zero_grad()
        outputs = model(images, text_inputs)
        loss = criterion(outputs["logits"], labels)
        loss.backward()
        optimizer.step()

        n_in_batch = labels.size(0)
        # Weight the batch-mean loss by batch size so a smaller final batch
        # does not skew the epoch average.
        loss_sum += loss.item() * n_in_batch
        predicted = outputs["logits"].argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += n_in_batch

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100 * correct / total

# Short training run for the demo; the history lists feed the progress plot.
num_epochs = 3
print(f"🚀 Starting training demo ({num_epochs} epochs)...")
train_losses, train_accs = [], []

for epoch_number in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    print(f"Epoch {epoch_number}/{num_epochs}: Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")

print("✅ Training demo completed!")

## 6. Model Evaluation
Đánh giá hiệu suất model trên validation set và phân tích kết quả.

In [None]:
# Evaluation helper used for the validation and test analyses.
def evaluate_model(model, dataloader, device, max_length=128):
    """Run inference over *dataloader* and collect predictions and features.

    Args:
        model: Multimodal model exposing ``get_text_tokenizer()`` and accepting
            ``(images, text_inputs, return_features=True)``; its output dict
            must contain ``logits``, ``visual_features``, ``text_features`` and
            ``fused_features``.
        dataloader: Any iterable of batches shaped like ``collate_fn`` output
            (a DataLoader or, e.g., a plain list of batches).
        device: Device to run the model on.
        max_length: Maximum token length passed to the tokenizer.

    Returns:
        ``(predictions, labels, features)``: predictions and labels are lists
        of class ids. ``features`` maps ``'visual'``, ``'text'`` and ``'fused'``
        to CPU tensors concatenated along dim 0. When no batches are yielded,
        each feature entry is an empty ``(0, 0)`` tensor instead of raising.
    """
    model.eval()
    tokenizer = model.get_text_tokenizer()

    # Short feature name -> key in the model's output dict.
    feature_keys = {'visual': 'visual_features', 'text': 'text_features', 'fused': 'fused_features'}
    all_predictions = []
    all_labels = []
    feature_chunks = {key: [] for key in feature_keys}

    with torch.no_grad():
        for batch in dataloader:
            images = batch['images'].to(device)
            text_inputs = tokenizer(
                batch['texts'],
                max_length=max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(device)

            outputs = model(images, text_inputs, return_features=True)
            predictions = torch.argmax(outputs["logits"], dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            # Labels are only compared on the host, so skip the device round-trip.
            all_labels.extend(batch['labels'].cpu().numpy())
            for key, output_key in feature_keys.items():
                feature_chunks[key].append(outputs[output_key].cpu())

    # torch.cat raises on an empty list, so fall back to an empty tensor.
    all_features = {
        key: torch.cat(chunks, dim=0) if chunks else torch.empty(0, 0)
        for key, chunks in feature_chunks.items()
    }
    return all_predictions, all_labels, all_features

# Metrics used in this and the following cells.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Run the model over the validation split, collecting predictions and features.
print("🔍 Evaluating model on validation set...")
val_predictions, val_labels, val_features = evaluate_model(model, val_loader, device)

val_accuracy = accuracy_score(val_labels, val_predictions)
print(f"📊 Validation Accuracy: {val_accuracy:.4f}")

# Per-class precision/recall/F1 on the validation split.
# Names are looked up by class id so they line up with the numeric labels.
# `train_df['pill_class'].unique()` is ordered by first appearance in the
# shuffled split, not by class id. Its sliced length could also disagree with
# the ids present (the 100-way head can predict ids absent from the data),
# which made classification_report raise.
id_to_name = dict(zip(train_df['class_id'], train_df['pill_class']))
report_labels = sorted(set(val_labels) | set(val_predictions))
class_names = [id_to_name.get(label, f"Class {label}") for label in report_labels]
report = classification_report(
    val_labels, val_predictions,
    labels=report_labels, target_names=class_names,
    output_dict=True, zero_division=0,
)

print("\n📋 Classification Report:")
print(classification_report(
    val_labels, val_predictions,
    labels=report_labels, target_names=class_names, zero_division=0,
))

# Confusion matrix over every class id seen in the labels or predictions.
# Passing `labels` pins the row/column order so the axes can show the real
# class ids instead of positional indices 0..k-1.
cm_labels = sorted(set(val_labels) | set(val_predictions))
cm = confusion_matrix(val_labels, val_predictions, labels=cm_labels)
cm_ticks = [str(label) for label in cm_labels]
fig = px.imshow(
    cm,
    x=cm_ticks,
    y=cm_ticks,
    title="Confusion Matrix",
    labels=dict(x="Predicted", y="Actual"),
    color_continuous_scale="Blues"
)
# Force categorical axes so non-contiguous ids are not spread on a numeric scale.
fig.update_xaxes(type='category')
fig.update_yaxes(type='category')
fig.show()

# Mean L2 norm of each modality's feature vectors over the validation split.
print("\n🔍 Feature Analysis:")
feature_magnitudes = {
    name: torch.norm(val_features[name], dim=1).mean().item()
    for name in ('visual', 'text', 'fused')
}
visual_magnitude = feature_magnitudes['visual']
text_magnitude = feature_magnitudes['text']
fused_magnitude = feature_magnitudes['fused']

for display_name, magnitude in (("Visual", visual_magnitude), ("Text", text_magnitude), ("Fused", fused_magnitude)):
    print(f"  - {display_name} features magnitude: {magnitude:.4f}")

# Mean cosine similarity between each sample's visual and text features.
visual_norm = torch.nn.functional.normalize(val_features['visual'], dim=1)
text_norm = torch.nn.functional.normalize(val_features['text'], dim=1)
similarity = (visual_norm * text_norm).sum(dim=1).mean().item()
print(f"  - Visual-Text similarity: {similarity:.4f}")

# Per-epoch training loss and accuracy curves, side by side.
fig = make_subplots(rows=1, cols=2, subplot_titles=('Training Loss', 'Training Accuracy'))

for column_index, (history, trace_name) in enumerate(
    [(train_losses, 'Loss'), (train_accs, 'Accuracy')], start=1
):
    epoch_axis = list(range(1, len(history) + 1))
    fig.add_trace(go.Scatter(x=epoch_axis, y=history, name=trace_name), row=1, col=column_index)

fig.update_layout(title_text="Training Progress", showlegend=False)
fig.show()

## 7. Make Predictions
Demo inference trên dữ liệu mới và phân tích kết quả dự đoán.

In [None]:
# Single-sample inference helper.
def predict_pill(model, image, text_imprint, device, top_k=5, max_length=128):
    """Predict the class of one pill from its image and imprint text.

    Args:
        model: Multimodal model exposing ``get_text_tokenizer()`` and accepting
            ``(images, text_inputs, return_features=True)``.
        image: Either a numpy array (run through the module-level
            ``val_transform``) or an already-transformed tensor (presumably
            CxHxW, as ``val_transform`` produces); a batch dim is added here.
        text_imprint: Imprint string for the pill.
        device: Device to run inference on.
        top_k: Number of best classes to return; clamped to the number of
            classes the model outputs, since ``torch.topk`` raises otherwise.
        max_length: Maximum token length passed to the tokenizer.

    Returns:
        Dict with ``top_classes`` (class ids, best first), ``top_probabilities``
        (softmax probabilities in the same order) and ``features``
        (``visual``/``text``/``fused`` tensors moved to CPU).
    """
    model.eval()
    tokenizer = model.get_text_tokenizer()

    # Preprocess the image and add a batch dimension.
    if isinstance(image, np.ndarray):
        image = val_transform(image)
    image = image.unsqueeze(0).to(device)

    text_inputs = tokenizer(
        [text_imprint],
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(image, text_inputs, return_features=True)
        probs = torch.softmax(outputs["logits"], dim=1)
        k = min(top_k, probs.size(1))
        top_probs, top_indices = torch.topk(probs, k=k, dim=1)

    return {
        'top_classes': top_indices[0].cpu().numpy(),
        'top_probabilities': top_probs[0].cpu().numpy(),
        'features': {
            'visual': outputs["visual_features"].cpu(),
            'text': outputs["text_features"].cpu(),
            'fused': outputs["fused_features"].cpu(),
        },
    }

# Run single-sample inference on a few test examples and visualise them.
print("🔮 Making predictions on test samples...")

# Class id -> pill name, so printed predictions are readable; ids that never
# occur in the training split are shown as "unknown".
id_to_name = dict(zip(train_df['class_id'], train_df['pill_class']))

# Keep only indices that exist in the test split.
test_indices = [idx for idx in (0, 5, 10, 15, 20) if idx < len(test_dataset)]
# squeeze=False keeps `axes` 2-D even when a single index remains.
fig, axes = plt.subplots(1, len(test_indices), figsize=(20, 4), squeeze=False)

for ax, idx in zip(axes[0], test_indices):
    sample = test_dataset[idx]
    text = sample['text']
    true_class = sample['class_name']

    results = predict_pill(model, sample['image'], text, device)
    top_class = results['top_classes'][0]
    top_prob = results['top_probabilities'][0]

    # Display the raw image: test_dataset resets test_df's index, so the same
    # positional index addresses the same row in both.
    ax.imshow(test_df.iloc[idx]['image_array'])
    # Real newlines: the previous "\\n" rendered a literal backslash-n.
    ax.set_title(
        f"Text: {text}\n"
        f"True: {true_class[:15]}...\n"
        f"Pred: Class {top_class}\n"
        f"Conf: {top_prob:.3f}",
        fontsize=10
    )
    ax.axis('off')

    print(f"\n📋 Sample {idx}:")
    print(f"  Text: '{text}'")
    print(f"  True class: {true_class}")
    print("  Top 3 predictions:")
    # Slicing guards against fewer than 3 returned classes.
    top3 = zip(results['top_classes'][:3], results['top_probabilities'][:3])
    for rank, (class_id, prob) in enumerate(top3, start=1):
        class_name = id_to_name.get(int(class_id), "unknown")
        print(f"    {rank}. Class {class_id} ({class_name}): {prob:.3f}")

plt.suptitle("🔮 Inference Results on Test Samples", fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Feature-magnitude histograms for one batch of test samples.
# Real newline: the previous "\\n" printed a literal backslash-n.
print("\n🎨 Multimodal Feature Analysis:")

# Loader over the test split; only its first batch (up to 32 samples) is used.
test_loader_single = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn
)
test_batch = next(iter(test_loader_single))
# evaluate_model iterates over whatever it is given, so a one-element list
# of batches works.
test_predictions, test_labels_batch, test_features_batch = evaluate_model(
    model, [test_batch], device
)

fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=('Visual Features', 'Text Features', 'Fused Features')
)

# Map each feature type to its subplot explicitly instead of relying on the
# dict's iteration order matching the subplot titles.
for column_index, feature_type in enumerate(('visual', 'text', 'fused'), start=1):
    feature_norms = torch.norm(test_features_batch[feature_type], dim=1).numpy()
    fig.add_trace(
        go.Histogram(x=feature_norms, name=f'{feature_type.title()} Features'),
        row=1, col=column_index
    )

fig.update_layout(
    title_text="Feature Magnitude Distributions",
    showlegend=False
)
fig.show()

print("✅ Inference demo completed!")
print("\n🎉 Notebook demo hoàn thành! Hệ thống multimodal pill recognition đã được demo thành công.")