# Custom CNN + Text Encoder Training for Image-Text Retrieval (Fashion MNIST)

This notebook trains a dual-encoder model on the Fashion MNIST dataset:
- **Image Encoder**: Custom CNN (handles grayscale images)
- **Text Encoder**: LSTM-based text encoder
- **Dataset**: Fashion MNIST (70,000 images, 10 fashion categories)
- **Objective**: Project images and text to shared embedding space
- **Optimization**: Intel CPU optimized with OpenVINO

## Features:
- Contrastive learning (CLIP-style)
- Supports both image and text queries
- Intel CPU optimized
- FAISS-based similarity search
- Grayscale image support (Fashion MNIST images are loaded as grayscale, then replicated to 3 channels for the CNN)

## 1. Install Dependencies

In [None]:
# Install required packages (notebook shell escapes; run once per environment).
#   torch / torchvision : model definition, training loop and image transforms
#   openvino            : imported later as `ov` for the IR conversion step
#   faiss-cpu           : installed here but not imported by any cell in this notebook
#   remaining packages  : image loading, numerics, plotting and progress bars
# NOTE(review): the cells in this notebook only call `ov.convert_model` / `ov.save_model`
# from the `openvino` package; `openvino-dev` may be unnecessary -- confirm before pinning.
! pip install torch torchvision
! pip install openvino openvino-dev
! pip install faiss-cpu
! pip install pillow numpy matplotlib seaborn pandas scikit-learn tqdm

## 2. Import Libraries

In [None]:
import os
import sys
import random
import time
from pathlib import Path
from collections import defaultdict
from typing import List, Dict, Tuple

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Forget any previously imported `text_descriptions` so the import below is
# resolved again against the path configured here.
sys.modules.pop('text_descriptions', None)

# Put the Fashion MNIST folder first on sys.path; an existing entry for it is
# removed beforehand so the folder is not listed twice.
fashion_mnist_path = r'e:\Projects\AI Based\RecTrio\datasets\fashion_mnist'
try:
    sys.path.remove(fashion_mnist_path)
except ValueError:
    pass  # the folder was not on sys.path yet
sys.path.insert(0, fashion_mnist_path)

# With the search path prepared, load the class names and their descriptions.
from text_descriptions import FASHION_DESCRIPTIONS, CLASSES

print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ Available classes: {CLASSES}")
print(f"✓ Number of classes: {len(CLASSES)}")
print(f"✓ Sample class descriptions: {len(FASHION_DESCRIPTIONS[CLASSES[0]])} per class")

## 3. Configuration

In [None]:
# ---------------------------------------------------------------------------
# Filesystem locations
# ---------------------------------------------------------------------------
DATASET_PATH = Path(r'e:\Projects\AI Based\RecTrio\datasets\fashion_mnist\processed\train')
OUTPUT_DIR = Path(r'e:\Projects\AI Based\RecTrio\V1\models\fashion_cnn')
VECTOR_DB_DIR = OUTPUT_DIR / 'vector_db'

# Make sure both output folders exist before anything tries to write to them.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
VECTOR_DB_DIR.mkdir(parents=True, exist_ok=True)

# ---------------------------------------------------------------------------
# Model hyper-parameters
# ---------------------------------------------------------------------------
IMAGE_SIZE = 224        # Side length images are resized to
# NOTE(review): INPUT_CHANNELS is only printed below; the dataset class in this
# notebook converts every image to 3-channel RGB -- confirm which is intended.
INPUT_CHANNELS = 1      # Grayscale images (Fashion MNIST)
EMBEDDING_DIM = 256     # Shared embedding space dimension
CNN_FEATURES = 512      # CNN output features
TEXT_HIDDEN_DIM = 256   # LSTM hidden dimension
MAX_TEXT_LENGTH = 20    # Maximum words in description

# ---------------------------------------------------------------------------
# Optimisation settings
# ---------------------------------------------------------------------------
BATCH_SIZE = 64
EPOCHS = 30
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
TEMPERATURE = 0.07      # Contrastive loss temperature

# Fractions of the dataset used for training / validation
TRAIN_SPLIT = 0.8
VAL_SPLIT = 0.2

# Seed handed to set_seed()
RANDOM_SEED = 42

# Everything runs on the CPU
DEVICE = 'cpu'  # Intel CPU optimized

# One-shot summary of the settings that matter most when reading the logs.
print(
    f"✓ Dataset: {DATASET_PATH}",
    f"✓ Output directory: {OUTPUT_DIR}",
    f"✓ Vector DB directory: {VECTOR_DB_DIR}",
    f"✓ Device: {DEVICE}",
    f"✓ Input channels: {INPUT_CHANNELS} (grayscale)",
    f"✓ Embedding dimension: {EMBEDDING_DIM}",
    sep="\n",
)

## 4. Set Random Seed

In [None]:
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    CUDA generators are seeded as well when CUDA is available, and cuDNN is
    switched to deterministic mode with auto-tuning (benchmark) turned off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

# Seed all RNGs once, ahead of the dataset shuffling and model initialisation
# performed in later cells.
set_seed(RANDOM_SEED)
print(f"✓ Random seed set to {RANDOM_SEED}")

## 5. Build Vocabulary from Text Descriptions

In [None]:
class Vocabulary:
    """Whitespace-token vocabulary mapping lower-cased words to integer ids.

    Index 0 is reserved for ``<PAD>`` and index 1 for ``<UNK>``; every other
    word receives the next free id in the order it is first seen.

    Attributes:
        word2idx: word -> id.
        idx2word: id -> word.
        word_count: word -> number of times it was passed to ``add_word``.
        n_words: number of entries in ``word2idx`` (reserved tokens included).
    """

    def __init__(self):
        self.word2idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx2word = {0: '<PAD>', 1: '<UNK>'}
        self.word_count = {}
        self.n_words = 2

    def add_sentence(self, sentence):
        """Lower-case ``sentence``, split it on whitespace and add every word."""
        for word in sentence.lower().split():
            self.add_word(word)

    def add_word(self, word):
        """Register ``word`` (assigning a new id if unseen) and bump its count.

        The reserved tokens are already present in ``word2idx`` but start with
        no entry in ``word_count``; counting through ``dict.get`` keeps a
        direct ``add_word('<UNK>')`` from raising ``KeyError``.
        """
        if word not in self.word2idx:
            self.word2idx[word] = self.n_words
            self.idx2word[self.n_words] = word
            self.n_words += 1
        self.word_count[word] = self.word_count.get(word, 0) + 1

    def encode(self, sentence, max_length=20):
        """Encode ``sentence`` as a list of exactly ``max_length`` ids.

        The sentence is lower-cased and split on whitespace, truncated to
        ``max_length`` words, unknown words map to 1 (``<UNK>``) and the
        result is right-padded with 0 (``<PAD>``).
        """
        words = sentence.lower().split()[:max_length]
        indices = [self.word2idx.get(word, 1) for word in words]
        # Right-pad so every encoded sentence has the same length.
        indices += [0] * (max_length - len(indices))
        return indices

# Feed every description of every class into a fresh vocabulary.
vocab = Vocabulary()
for descriptions in FASHION_DESCRIPTIONS.values():
    for sentence in descriptions:
        vocab.add_sentence(sentence)

# Report the resulting size and the first few entries (reserved tokens come first).
sample_words = list(vocab.word2idx)[:20]
print(f"✓ Vocabulary size: {vocab.n_words}")
print(f"  Sample words: {sample_words}")

## 6. Dataset Class

In [None]:
class ImageTextDataset(Dataset):
    """Dataset that pairs Fashion MNIST images with text descriptions.

    Images are discovered under ``root_dir/<class_name>/`` for every name in
    ``CLASSES``. Each item is ``(image, text_tensor, class_idx)`` where the text
    is one randomly chosen description of the image's class.

    Args:
        root_dir: folder containing one sub-folder per class.
        vocab: object whose ``encode(sentence, max_length)`` returns token ids.
        transform: optional callable applied to the loaded PIL image.
        split: label used only in the summary line printed on construction.
    """

    def __init__(self, root_dir, vocab, transform=None, split='train'):
        self.root_dir = Path(root_dir)
        self.vocab = vocab
        self.transform = transform
        self.split = split

        # Collect all images
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {cls: idx for idx, cls in enumerate(CLASSES)}

        valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}

        for class_name in CLASSES:
            class_dir = self.root_dir / class_name
            # is_dir() also rejects a stray file carrying the class name, which
            # would otherwise make iterdir() raise.
            if not class_dir.is_dir():
                print(f"Warning: {class_dir} not found, skipping...")
                continue

            # iterdir() yields entries in arbitrary order; sorting keeps the
            # index -> image mapping stable, so a seeded shuffle of the indices
            # produces the same train/val split on every run and machine.
            for img_path in sorted(class_dir.iterdir()):
                if img_path.suffix.lower() in valid_extensions:
                    self.image_paths.append(str(img_path))
                    self.labels.append(class_name)

        print(f"✓ {split} dataset: {len(self.image_paths)} images across {len(CLASSES)} classes")

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, text_tensor, class_idx)`` for the image at ``idx``."""
        img_path = self.image_paths[idx]

        # Open inside a context manager so the file handle is released right
        # away; convert() returns a fully loaded copy that outlives the file.
        # The image is reduced to grayscale first and then replicated to three
        # channels, which is what the transforms and CNN in this notebook expect.
        with Image.open(img_path) as source:
            image = source.convert('L').convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Pick one of the class's descriptions at random for this sample.
        class_name = self.labels[idx]
        descriptions = FASHION_DESCRIPTIONS[class_name]
        description = random.choice(descriptions)

        # Encode text to a fixed-length tensor of token ids
        text_indices = self.vocab.encode(description, MAX_TEXT_LENGTH)
        text_tensor = torch.tensor(text_indices, dtype=torch.long)

        # Class index
        class_idx = self.class_to_idx[class_name]

        return image, text_tensor, class_idx

print("✓ Dataset class defined")

## 7. Model Architecture

In [None]:
class ImageEncoder(nn.Module):
    """Custom CNN that maps an image batch to L2-normalised embeddings.

    Four convolutional stages (64, 128, 256, 512 channels), each made of two
    3x3 conv + batch-norm + ReLU layers followed by a 2x2 max-pool, are
    followed by global average pooling and a two-layer projection head.

    Args:
        embedding_dim: size of the output embedding.
        in_channels: number of channels of the input images. Defaults to 3,
            matching the RGB tensors produced by the dataset class in this
            notebook; pass 1 to feed single-channel images directly.
    """

    def __init__(self, embedding_dim=256, in_channels=3):
        super().__init__()

        # Build one flat nn.Sequential so the layer indices (and therefore the
        # state_dict keys, e.g. ``conv_layers.0.weight``) are the same as in a
        # hand-written layer list.
        # Each stage halves the spatial size; with 224x224 inputs that is
        # 224 -> 112 -> 56 -> 28 -> 14.
        layers = []
        previous_channels = in_channels
        for stage_channels in (64, 128, 256, 512):
            layers += [
                nn.Conv2d(previous_channels, stage_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(stage_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(stage_channels, stage_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(stage_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]
            previous_channels = stage_channels
        self.conv_layers = nn.Sequential(*layers)

        # Global average pooling collapses each feature map to a single value,
        # so the head does not depend on the input resolution.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Projection to embedding space
        self.projection = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, embedding_dim)
        )

    def forward(self, x):
        """Encode ``x`` of shape [batch, in_channels, H, W] to [batch, embedding_dim]."""
        x = self.conv_layers(x)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)  # [batch, 512, 1, 1] -> [batch, 512]
        x = self.projection(x)
        # L2 normalize so dot products between embeddings are cosine similarities
        x = F.normalize(x, p=2, dim=1)
        return x


class TextEncoder(nn.Module):
    """Bidirectional-LSTM sentence encoder producing L2-normalised embeddings.

    Token ids are embedded into 128-dimensional vectors (id 0 is the padding
    index), passed through a single-layer bidirectional LSTM, and the final
    hidden states of both directions are concatenated and projected to
    ``embedding_dim``.

    Args:
        vocab_size: number of rows in the token embedding table.
        embedding_dim: size of the output embedding.
        hidden_dim: LSTM hidden size per direction.
    """

    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=256):
        super().__init__()

        token_dim = 128  # width of the learned token embeddings
        self.embedding = nn.Embedding(vocab_size, token_dim, padding_idx=0)
        self.lstm = nn.LSTM(token_dim, hidden_dim, batch_first=True, bidirectional=True)

        # Head mapping the concatenated forward/backward states to the shared space
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim * 2, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, embedding_dim)
        )

    def forward(self, x):
        """Encode token ids ``x`` of shape [batch, max_length] to [batch, embedding_dim]."""
        token_vectors = self.embedding(x)  # [batch, max_length, 128]

        # Only the final hidden states are needed; the per-step outputs are dropped.
        # NOTE(review): the LSTM runs over the whole padded sequence (no packing),
        # so the forward direction's final state is taken after any trailing
        # padding steps -- confirm this is acceptable.
        _, (final_hidden, _) = self.lstm(token_vectors)

        forward_state = final_hidden[0]
        backward_state = final_hidden[1]
        sentence_features = torch.cat([forward_state, backward_state], dim=1)  # [batch, hidden_dim*2]

        projected = self.projection(sentence_features)
        # Unit-length output so dot products are cosine similarities
        return F.normalize(projected, p=2, dim=1)


class DualEncoder(nn.Module):
    """Wrapper holding one image encoder and one text encoder.

    Both sub-encoders are built with the same ``embedding_dim`` so their
    outputs live in a common space.

    Args:
        vocab_size: vocabulary size handed to the text encoder.
        embedding_dim: output size of both encoders.
        text_hidden_dim: hidden size handed to the text encoder.
    """

    def __init__(self, vocab_size, embedding_dim=256, text_hidden_dim=256):
        super().__init__()
        self.image_encoder = ImageEncoder(embedding_dim=embedding_dim)
        self.text_encoder = TextEncoder(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            hidden_dim=text_hidden_dim,
        )

    def forward(self, images, texts):
        """Return ``(image_embeddings, text_embeddings)`` for the two inputs."""
        return self.image_encoder(images), self.text_encoder(texts)

print("✓ Model architecture defined")

## 8. Contrastive Loss (CLIP-style)

In [None]:
class ContrastiveLoss(nn.Module):
    """Symmetric InfoNCE (CLIP-style) loss over a batch of image/text pairs.

    The i-th image and the i-th text are treated as the only matching pair;
    every other combination in the batch is a negative.

    NOTE(review): when a batch contains several samples with the same class or
    description, those cross pairs are still scored as negatives -- confirm
    this is the intended objective.

    Args:
        temperature: divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, image_embeddings, text_embeddings):
        """Return ``(loss, accuracy)`` as scalar tensors.

        ``accuracy`` is the mean of the image-to-text and text-to-image top-1
        retrieval accuracies within the batch.
        """
        # Re-normalise so the similarity below is a cosine regardless of input scale
        img_unit = F.normalize(image_embeddings, p=2, dim=1)
        txt_unit = F.normalize(text_embeddings, p=2, dim=1)

        # Rows index images, columns index texts
        sim_i2t = torch.matmul(img_unit, txt_unit.t()) / self.temperature
        sim_t2i = sim_i2t.t()

        # The matching partner of sample i sits at position i
        targets = torch.arange(img_unit.shape[0], device=img_unit.device)

        # Average of the two retrieval directions
        loss = (F.cross_entropy(sim_i2t, targets) + F.cross_entropy(sim_t2i, targets)) / 2

        # Retrieval accuracy is a metric only, so keep it out of the autograd graph
        with torch.no_grad():
            hits_i2t = (torch.argmax(sim_i2t, dim=1) == targets).float().mean()
            hits_t2i = (torch.argmax(sim_t2i, dim=1) == targets).float().mean()
            accuracy = (hits_i2t + hits_t2i) / 2

        return loss, accuracy

print("✓ Contrastive loss defined")

## 9. Create DataLoaders

In [None]:
import copy  # used below to give the validation split its own dataset object

# Image transformations
# Training: resize plus light augmentation, then ImageNet-style normalisation.
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation: deterministic resize + normalisation only.
val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create datasets
full_dataset = ImageTextDataset(DATASET_PATH, vocab, transform=train_transform, split='full')

# Split into train and val
dataset_size = len(full_dataset)
if dataset_size == 0:
    # Fail here with a clear message instead of deep inside the DataLoader/training loop.
    raise RuntimeError(f"No images found under {DATASET_PATH}; cannot build train/val splits")

indices = list(range(dataset_size))
random.shuffle(indices)

train_size = int(TRAIN_SPLIT * dataset_size)
train_indices = indices[:train_size]
val_indices = indices[train_size:]

# Both Subsets must NOT share one dataset object: assigning
# `val_dataset.dataset.transform` on a shared object would also replace the
# training transform and silently switch off augmentation. A shallow copy
# reuses the already-collected path/label lists (so indices stay aligned) while
# carrying its own `transform` attribute.
val_source = copy.copy(full_dataset)
val_source.transform = val_transform

train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_source, val_indices)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Batches per epoch: {len(train_loader)}")

## 10. Initialize Model and Training Components

In [None]:
# Build the dual encoder and place it on the configured device
model = DualEncoder(
    vocab_size=vocab.n_words,
    embedding_dim=EMBEDDING_DIM,
    text_hidden_dim=TEXT_HIDDEN_DIM,
)
model = model.to(DEVICE)

# Contrastive objective shared by training and validation
criterion = ContrastiveLoss(temperature=TEMPERATURE)

# AdamW over every model parameter
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Cosine decay of the learning rate over the full run, bottoming out at 1e-6
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=1e-6,
)

# Tally parameter counts in a single pass
total_params = 0
trainable_params = 0
for weight in model.parameters():
    total_params += weight.numel()
    if weight.requires_grad:
        trainable_params += weight.numel()

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model initialized on {DEVICE}")

## 11. Training Loop

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch yields ``(images, texts, class_idx)``; the class index is not
    used. Gradients are clipped to a global norm of 1.0 before every step.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the number of batches.
    """
    model.train()
    running_loss = 0
    running_acc = 0

    progress = tqdm(dataloader, desc='Training')
    for image_batch, text_batch, _ in progress:
        image_batch = image_batch.to(device)
        text_batch = text_batch.to(device)

        optimizer.zero_grad()

        # Forward: embed both modalities and score them against each other
        img_vecs, txt_vecs = model(image_batch, text_batch)
        batch_loss, batch_acc = criterion(img_vecs, txt_vecs)

        # Backward: clip before stepping to keep updates bounded
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_value = batch_loss.item()
        acc_value = batch_acc.item()
        running_loss += loss_value
        running_acc += acc_value

        progress.set_postfix({'loss': f'{loss_value:.4f}', 'acc': f'{acc_value:.4f}'})

    n_batches = len(dataloader)
    return running_loss / n_batches, running_acc / n_batches


@torch.no_grad()
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Each batch yields ``(images, texts, class_idx)``; the class index is not
    used.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the number of batches.
    """
    model.eval()
    running_loss = 0
    running_acc = 0

    progress = tqdm(dataloader, desc='Validation')
    for image_batch, text_batch, _ in progress:
        image_batch = image_batch.to(device)
        text_batch = text_batch.to(device)

        # Forward only -- no optimiser involved here
        img_vecs, txt_vecs = model(image_batch, text_batch)
        batch_loss, batch_acc = criterion(img_vecs, txt_vecs)

        loss_value = batch_loss.item()
        acc_value = batch_acc.item()
        running_loss += loss_value
        running_acc += acc_value

        progress.set_postfix({'loss': f'{loss_value:.4f}', 'acc': f'{acc_value:.4f}'})

    n_batches = len(dataloader)
    return running_loss / n_batches, running_acc / n_batches

print("✓ Training functions defined")

## 12. Run Training

In [None]:
# Per-epoch metrics, appended to once per epoch and plotted afterwards
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

# The checkpoint is rewritten whenever validation accuracy improves
best_val_acc = 0.0
best_model_path = OUTPUT_DIR / 'best_multimodal_model.pth'

print("="*60)
print("Starting Training")
print("="*60)

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    # Read the learning rate BEFORE the scheduler advances, so the value logged
    # for this epoch is the one the optimizer actually used during it.
    epoch_lr = optimizer.param_groups[0]['lr']

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)

    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, DEVICE)

    # Update scheduler (affects the next epoch)
    scheduler.step()

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Print summary
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")
    print(f"LR: {epoch_lr:.2e}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # NOTE: 'vocab' stores the Vocabulary object itself, so the file is a
        # full pickle; reading it back with torch.load needs weights_only=False
        # on PyTorch versions whose default is weights_only=True.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'vocab': vocab,
            'config': {
                'embedding_dim': EMBEDDING_DIM,
                'text_hidden_dim': TEXT_HIDDEN_DIM,
                'max_text_length': MAX_TEXT_LENGTH,
                'image_size': IMAGE_SIZE
            }
        }, best_model_path)
        print(f"✓ Best model saved (val_acc: {val_acc:.4f})")

print("\n" + "="*60)
print("Training Completed!")
print(f"Best Validation Accuracy: {best_val_acc:.4f}")
print(f"Model saved to: {best_model_path}")
print("="*60)

## 13. Plot Training History

In [None]:
# Two side-by-side panels: loss on the left, accuracy on the right
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

epochs_range = range(1, len(history['train_loss']) + 1)

# (axis, train key, val key, train label, val label, y-axis label, title)
panels = (
    (axes[0], 'train_loss', 'val_loss', 'Train Loss', 'Val Loss',
     'Loss', 'Training and Validation Loss'),
    (axes[1], 'train_acc', 'val_acc', 'Train Accuracy', 'Val Accuracy',
     'Accuracy', 'Training and Validation Accuracy'),
)

for ax, train_key, val_key, train_label, val_label, y_label, title in panels:
    ax.plot(epochs_range, history[train_key], 'b-', label=train_label, linewidth=2)
    ax.plot(epochs_range, history[val_key], 'r-', label=val_label, linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Write the figure next to the model files, then display it
history_plot_path = OUTPUT_DIR / 'training_history.png'
plt.tight_layout()
plt.savefig(history_plot_path, dpi=150)
plt.show()

print(f"✓ Plot saved to {history_plot_path}")

## 14. Save Vocabulary

In [None]:
# Persist the vocabulary object with pickle next to the model checkpoint
vocab_path = OUTPUT_DIR / 'vocabulary.pkl'
with vocab_path.open('wb') as vocab_file:
    pickle.dump(vocab, vocab_file)

print(f"✓ Vocabulary saved to {vocab_path}")

## 15. Convert to OpenVINO (Intel CPU Optimization)

In [None]:
import openvino as ov

# Load best model
# The checkpoint written during training contains the pickled Vocabulary
# object, not just tensors. On PyTorch versions where torch.load defaults to
# weights_only=True that makes the load fail, so full unpickling is requested
# explicitly. This is only appropriate because the file was produced by this
# notebook -- never load untrusted checkpoints this way.
checkpoint = torch.load(best_model_path, map_location='cpu', weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # disable dropout / use running batch-norm statistics for export

# Extract separate encoders so each can be exported and served on its own
image_encoder = model.image_encoder
text_encoder = model.text_encoder

# Convert image encoder (example input: one 3-channel image at the training resolution)
print("Converting image encoder to OpenVINO...")
dummy_image = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)
image_ov_model = ov.convert_model(image_encoder, example_input=dummy_image)
ov.save_model(image_ov_model, OUTPUT_DIR / 'image_encoder.xml')
print(f"✓ Image encoder saved to {OUTPUT_DIR / 'image_encoder.xml'}")

# Convert text encoder (example input: one sequence of MAX_TEXT_LENGTH token ids)
print("Converting text encoder to OpenVINO...")
dummy_text = torch.randint(0, vocab.n_words, (1, MAX_TEXT_LENGTH))
text_ov_model = ov.convert_model(text_encoder, example_input=dummy_text)
ov.save_model(text_ov_model, OUTPUT_DIR / 'text_encoder.xml')
print(f"✓ Text encoder saved to {OUTPUT_DIR / 'text_encoder.xml'}")

print("\n✓ Models converted to OpenVINO IR format for Intel CPU optimization!")

## Summary

### Files Created:
- `best_multimodal_model.pth` - PyTorch model checkpoint
- `vocabulary.pkl` - Text vocabulary
- `image_encoder.xml/.bin` - OpenVINO image encoder (Intel optimized)
- `text_encoder.xml/.bin` - OpenVINO text encoder (Intel optimized)
- `training_history.png` - Training curves

### Next Steps:
1. Build embeddings database for all images
2. Create inference notebook for image/text queries
3. Deploy using OpenVINO for fast CPU inference