In [1]:
import h5py
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset, Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertModel, ViTModel, ViTImageProcessor

2025-11-12 04:19:19.594955: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762921159.617362     137 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762921159.624317     137 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [2]:
# =============================================================================
# Dataset
# =============================================================================
class FashionGenDataset(Dataset):
    def __init__(self, h5_path, visualize=False):
        self.h5 = h5py.File(h5_path, 'r')
        self.texts = self.h5['input_description']
        self.imgs = self.h5['input_image']
        self.visualize = visualize
        self.fe = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        caption = self.texts[idx][0].decode('utf-8', errors='ignore')
        raw_img = Image.fromarray(self.imgs[idx])
        cropped_img = self.crop_and_resize_img(raw_img)
        pixel_values = self.fe(images=cropped_img, return_tensors="pt")['pixel_values'].squeeze(0)
        return caption, pixel_values

    @staticmethod
    def crop_and_resize_img(img, threshold=245):
        img_np = np.array(img)
        non_white_mask = np.any(img_np < threshold, axis=2)
        if not np.any(non_white_mask):
            return img.resize((224, 224))
        coords = np.argwhere(non_white_mask)
        y0, x0 = coords.min(axis=0)
        y1, x1 = coords.max(axis=0) + 1
        cropped_img = img.crop((x0, y0, x1, y1))
        resized_img = cropped_img.resize((224, 224))
        return resized_img

In [3]:
# =============================================================================
# Encoders
# =============================================================================
class TextEncoder(nn.Module):
    def __init__(self, concept_terms, proj_dim=500, sim_threshold=0.7, freeze_bert=False):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.proj = nn.Linear(768, proj_dim)
        self.threshold = sim_threshold
        self.concept_terms = concept_terms
        self.freeze_bert = freeze_bert

        with torch.no_grad():
            concept_tokens = self.tokenizer(concept_terms, return_tensors="pt", padding=True, truncation=True)
            concepts_output = self.bert(**concept_tokens).last_hidden_state
            concepts_mask = concept_tokens["attention_mask"].unsqueeze(-1)
            concepts_avg = (concepts_output * concepts_mask).sum(dim=1) / concepts_mask.sum(dim=1)
        self.register_buffer("concept_raw_embeds", concepts_avg)

    def forward(self, texts):
        device = next(self.parameters()).device
        tokens = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
        attn_mask = tokens["attention_mask"].unsqueeze(-1)

        if self.freeze_bert:
            with torch.no_grad():
                output = self.bert(**tokens).last_hidden_state
        else:
            output = self.bert(**tokens).last_hidden_state

        output_norm = F.normalize(output, dim=-1)
        concept_norm = F.normalize(self.concept_raw_embeds, dim=-1)
        sim = torch.matmul(output_norm, concept_norm.T)

        enriched = []
        for b in range(sim.size(0)):
            enriched_words = []
            for l in range(sim.size(1)):
                if attn_mask[b, l] == 0:
                    continue
                sims = sim[b, l]
                valid_mask = sims >= self.threshold
                if valid_mask.any():
                    max_idx = sims[valid_mask].argmax()
                    concept_idx = valid_mask.nonzero(as_tuple=False)[max_idx]
                    best_concept = self.concept_raw_embeds[concept_idx]
                    enriched_word = (output[b, l] + best_concept) / 2
                else:
                    enriched_word = output[b, l]
                enriched_words.append(enriched_word)
            enriched_sentence = torch.stack(enriched_words).mean(dim=0)
            enriched.append(enriched_sentence)

        enriched = torch.stack(enriched)
        return self.proj(enriched)


class ImageEncoder(nn.Module):
    def __init__(self, proj_dim=500, topk=8):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224', output_attentions=True)
        self.proj = nn.Linear(768, proj_dim)
        self.topk = topk

    def forward(self, pixel_values):
        vit_out = self.vit(pixel_values=pixel_values)
        out = vit_out.last_hidden_state
        attn = vit_out.attentions[-1]
        patches = out[:, 1:]
        scores = attn[:, :, 0, 1:].mean(dim=1)
        k = min(self.topk, patches.size(1))
        idx = scores.topk(k, dim=1).indices
        selected = patches.gather(1, idx.unsqueeze(-1).expand(-1, -1, patches.size(2)))
        avg = selected.mean(dim=1)
        return self.proj(avg)

In [4]:
# =============================================================================
# SPF Module (Corrected)
# =============================================================================
class SemanticProgressiveFusionModule(nn.Module):
    def __init__(self, hidden_dim=500, num_prototypes=32, tau1=0.07, tau2=0.1, 
                 eta=1.0, gamma1=0.5, beta1=0.5):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        self.num_prototypes = num_prototypes
        
        self.tau1 = nn.Parameter(torch.tensor(tau1))
        self.tau2 = nn.Parameter(torch.tensor(tau2))
        self.eta = nn.Parameter(torch.tensor(eta))
        self.gamma1 = nn.Parameter(torch.tensor(gamma1))
        self.beta1 = nn.Parameter(torch.tensor(beta1))
        
        # Stage 1
        self.W_text_1 = nn.Linear(hidden_dim, hidden_dim)
        self.W_image_1 = nn.Linear(hidden_dim, hidden_dim)
        
        # Stage 2
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, hidden_dim))
        nn.init.xavier_uniform_(self.prototypes)
        
        self.W_image_2 = nn.Linear(hidden_dim, hidden_dim)
        self.W_text_2 = nn.Linear(hidden_dim, hidden_dim)
        
        # Stage 3
        self.text_out_proj = nn.Linear(hidden_dim, hidden_dim)
        self.image_out_proj = nn.Linear(hidden_dim, hidden_dim)
        
        self.norm_text = nn.LayerNorm(hidden_dim)
        self.norm_image = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.1)
        
    def compute_entropy(self, attention_weights):
        eps = 1e-8
        attention_weights = attention_weights.clamp(min=eps)
        entropy = -torch.sum(attention_weights * torch.log(attention_weights + eps), dim=-1)
        return entropy
    
    def stage1_fine_grained_alignment(self, X_text, X_image):
        text_proj = self.W_text_1(X_text)
        image_proj = self.W_image_1(X_image)
        
        S_1 = torch.tanh(torch.matmul(text_proj.unsqueeze(1), image_proj.unsqueeze(2)).squeeze())
        A_1 = F.softmax(S_1 / self.tau1, dim=0)
        X_text_1 = X_text + self.gamma1 * A_1.unsqueeze(-1) * X_image
        entropy_1 = self.compute_entropy(A_1.unsqueeze(-1))
        
        return X_text_1, entropy_1, A_1
    
    def stage2_part_level_prototypes(self, X_text_1, X_image, A_1):
        P_2_norm = F.normalize(self.prototypes, dim=-1)
        image_proj = self.W_image_2(X_image)
        image_proj_norm = F.normalize(image_proj, dim=-1)
        
        proto_sim = torch.matmul(image_proj_norm, P_2_norm.t())
        R_2 = F.softmax(proto_sim / self.tau2, dim=-1)
        U_2 = torch.matmul(R_2, self.prototypes)
        
        text_proj = self.W_text_2(X_text_1)
        part_affinity = torch.matmul(text_proj.unsqueeze(1), U_2.unsqueeze(2)).squeeze()
        
        # ✅ CRITICAL: beta1 residual
        A_2 = torch.sigmoid(part_affinity) + self.beta1 * A_1
        X_text_2 = A_2.unsqueeze(-1) * U_2
        entropy_2 = self.compute_entropy(R_2)
        
        return X_text_2, X_image, entropy_2, R_2
    
    def stage3_entropy_regulated_fusion(self, X_text, X_image, 
                                       X_text_1, entropy_1,
                                       X_text_2, X_image_2, entropy_2):
        entropy_0 = torch.zeros_like(entropy_1)
        
        conf_0 = torch.exp(-entropy_0 / self.eta)
        conf_1 = torch.exp(-entropy_1 / self.eta)
        conf_2 = torch.exp(-entropy_2 / self.eta)
        
        total_conf = conf_0 + conf_1 + conf_2 + 1e-8
        
        omega_0 = conf_0 / total_conf
        omega_1 = conf_1 / total_conf
        omega_2 = conf_2 / total_conf
        
        X_text_fused = (omega_0.unsqueeze(-1) * X_text + 
                        omega_1.unsqueeze(-1) * X_text_1 + 
                        omega_2.unsqueeze(-1) * X_text_2)
        
        X_image_fused = (omega_0.unsqueeze(-1) * X_image + 
                         omega_1.unsqueeze(-1) * X_image + 
                         omega_2.unsqueeze(-1) * X_image_2)
        
        X_text_final = self.text_out_proj(X_text_fused)
        X_image_final = self.image_out_proj(X_image_fused)
        
        X_text_final = self.norm_text(X_text_final)
        X_image_final = self.norm_image(X_image_final)
        
        X_text_final = self.dropout(X_text_final)
        X_image_final = self.dropout(X_image_final)
        
        stats = {
            'omega_0': omega_0.mean().item(),
            'omega_1': omega_1.mean().item(),
            'omega_2': omega_2.mean().item(),
            'entropy_1': entropy_1.mean().item(),
            'entropy_2': entropy_2.mean().item()
        }
        
        return X_text_final, X_image_final, stats
    
    def forward(self, X_text, X_image):
        X_text_1, entropy_1, A_1 = self.stage1_fine_grained_alignment(X_text, X_image)
        X_text_2, X_image_2, entropy_2, R_2 = self.stage2_part_level_prototypes(
            X_text_1, X_image, A_1
        )
        X_text_final, X_image_final, stats = self.stage3_entropy_regulated_fusion(
            X_text, X_image, X_text_1, entropy_1, X_text_2, X_image_2, entropy_2
        )
        
        return X_text_final, X_image_final, stats

In [5]:
# =============================================================================
# Training Functions
# =============================================================================
def train_spf_module(text_encoder, image_encoder, train_loader, val_loader, 
                     epochs=10, lr=1e-4, device='cuda'):
    spf = SemanticProgressiveFusionModule(
        hidden_dim=500, num_prototypes=32, tau1=0.07, tau2=0.1,
        eta=1.0, gamma1=0.5, beta1=0.5
    ).to(device)
    
    optimizer = torch.optim.AdamW(spf.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    text_encoder.eval()
    image_encoder.eval()
    spf.train()
    
    print(f"Training SPF: {sum(p.numel() for p in spf.parameters()):,} params | LR: {lr} | Epochs: {epochs}\n")
    
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        epoch_loss = 0.0
        epoch_align = 0.0
        num_batches = 0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        
        for captions, images in progress_bar:
            images = images.to(device)
            batch_size = images.size(0)
            
            with torch.no_grad():
                text_feats = text_encoder(captions)
                image_feats = image_encoder(images)
            
            refined_text, refined_image, stats = spf(text_feats, image_feats)
            
            # Contrastive loss
            text_norm = F.normalize(refined_text, dim=-1)
            image_norm = F.normalize(refined_image, dim=-1)
            logits = torch.matmul(text_norm, image_norm.t()) / 0.07
            labels = torch.arange(batch_size, device=device)
            
            loss_t2i = F.cross_entropy(logits, labels)
            loss_i2t = F.cross_entropy(logits.t(), labels)
            contrastive_loss = (loss_t2i + loss_i2t) / 2
            
            # Alignment loss
            positive_sim = F.cosine_similarity(refined_text, refined_image, dim=1)
            alignment_loss = -positive_sim.mean()
            
            # Entropy regularization
            entropy_loss = stats['entropy_1'] + stats['entropy_2']
            
            # Diversity regularization
            proto_sim = torch.matmul(
                F.normalize(spf.prototypes, dim=-1),
                F.normalize(spf.prototypes, dim=-1).t()
            )
            proto_sim = proto_sim - torch.eye(spf.num_prototypes, device=device)
            diversity_loss = torch.abs(proto_sim).mean()
            
            total_loss = (
                1.0 * contrastive_loss +
                0.3 * alignment_loss +
                0.1 * entropy_loss +
                0.1 * diversity_loss
            )
            
            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(spf.parameters(), max_norm=1.0)
            optimizer.step()
            
            epoch_loss += total_loss.item()
            epoch_align += positive_sim.mean().item()
            num_batches += 1
            
            progress_bar.set_postfix({
                'loss': f"{total_loss.item():.4f}",
                'align': f"{positive_sim.mean().item():.4f}",
                'ω1': f"{stats['omega_1']:.3f}",
                'ω2': f"{stats['omega_2']:.3f}"
            })
        
        avg_loss = epoch_loss / num_batches
        avg_align = epoch_align / num_batches
        
        print(f"\nEpoch {epoch+1}/{epochs}: Loss={avg_loss:.4f}, Align={avg_align:.4f}")
        
        scheduler.step()
        
        if val_loader is not None:
            val_loss, val_align = validate_spf(spf, text_encoder, image_encoder, val_loader, device)
            print(f"Val: Loss={val_loss:.4f}, Align={val_align:.4f}")
            
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': spf.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'config': {
                        'hidden_dim': 500,
                        'num_prototypes': 32,
                        'tau1': spf.tau1.item(),
                        'tau2': spf.tau2.item(),
                        'eta': spf.eta.item(),
                        'gamma1': spf.gamma1.item(),
                        'beta1': spf.beta1.item()
                    }
                }, 'spf_best.pth')
                print("✅ Best model saved!")
        print()
    
    return spf

def validate_spf(spf, text_encoder, image_encoder, val_loader, device):
    spf.eval()
    total_loss = 0.0
    total_align = 0.0
    num_batches = 0
    
    with torch.no_grad():
        for captions, images in val_loader:
            images = images.to(device)
            batch_size = images.size(0)
            
            text_feats = text_encoder(captions)
            image_feats = image_encoder(images)
            refined_text, refined_image, _ = spf(text_feats, image_feats)
            
            text_norm = F.normalize(refined_text, dim=-1)
            image_norm = F.normalize(refined_image, dim=-1)
            logits = torch.matmul(text_norm, image_norm.t()) / 0.07
            labels = torch.arange(batch_size, device=device)
            
            loss_t2i = F.cross_entropy(logits, labels)
            loss_i2t = F.cross_entropy(logits.t(), labels)
            loss = (loss_t2i + loss_i2t) / 2
            
            alignment = F.cosine_similarity(refined_text, refined_image, dim=1).mean()
            
            total_loss += loss.item()
            total_align += alignment.item()
            num_batches += 1
    
    spf.train()
    return total_loss / num_batches, total_align / num_batches

In [6]:
# =============================================================================
# Main Training
# =============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")

print("Loading dataset...")
full_ds = FashionGenDataset('/kaggle/input/fashiongen-validation/fashiongen_256_256_train.h5')

subset_indices = list(range(60000, 90000))
subset_ds = Subset(full_ds, subset_indices)

train_size = int(0.9 * len(subset_ds))
val_size = len(subset_ds) - train_size
train_ds, val_ds = random_split(subset_ds, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=2)

print(f"Train: {len(train_ds)} | Val: {len(val_ds)}\n")

# Load pretrained encoders
ckpt = torch.load(
    '/kaggle/input/module-1-2-models-only-high-sim/trained_encoders_complete_only_high.pth',
    map_location=device
)

text_encoder = TextEncoder(
    ckpt['concept_terms'],
    proj_dim=ckpt['proj_dim']
).to(device)

image_encoder = ImageEncoder(
    proj_dim=ckpt['proj_dim']
).to(device)

text_encoder.load_state_dict(ckpt['text_encoder_state_dict'])
image_encoder.load_state_dict(ckpt['image_encoder_state_dict'])

print("Encoders loaded (frozen)\n")

# Train
print("="*80)
print("TRAINING SPF MODULE 3 (CORRECTED)")
print("="*80 + "\n")

trained_spf = train_spf_module(
    text_encoder, image_encoder, train_loader, val_loader,
    epochs=10, lr=1e-4, device=device
)

# Save final
torch.save({
    'model_state_dict': trained_spf.state_dict(),
    'config': {
        'hidden_dim': 500,
        'num_prototypes': 32,
        'tau1': trained_spf.tau1.item(),
        'tau2': trained_spf.tau2.item(),
        'eta': trained_spf.eta.item(),
        'gamma1': trained_spf.gamma1.item(),
        'beta1': trained_spf.beta1.item()
    }
}, 'spf_final.pth')

print("\n✅ Training completed!")

Device: cuda

Loading dataset...
Train: 27000 | Val: 3000



Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Encoders loaded (frozen)

TRAINING SPF MODULE 3 (CORRECTED)

Training SPF: 1,521,005 params | LR: 0.0001 | Epochs: 10



Epoch 1/10: 100%|██████████| 422/422 [10:07<00:00,  1.44s/it, loss=0.5010, align=0.5866, ω1=0.352, ω2=0.297]


Epoch 1/10: Loss=1.1310, Align=0.4857





Val: Loss=0.9430, Align=0.6202
✅ Best model saved!



Epoch 2/10: 100%|██████████| 422/422 [10:02<00:00,  1.43s/it, loss=0.6663, align=0.5861, ω1=0.354, ω2=0.287]


Epoch 2/10: Loss=0.6815, Align=0.5823





Val: Loss=0.7677, Align=0.6571
✅ Best model saved!



Epoch 3/10: 100%|██████████| 422/422 [10:01<00:00,  1.43s/it, loss=0.6654, align=0.6085, ω1=0.363, ω2=0.269]


Epoch 3/10: Loss=0.5744, Align=0.6104





Val: Loss=0.7021, Align=0.6841
✅ Best model saved!



Epoch 4/10: 100%|██████████| 422/422 [10:01<00:00,  1.43s/it, loss=0.4677, align=0.6195, ω1=0.379, ω2=0.242]


Epoch 4/10: Loss=0.5276, Align=0.6269





Val: Loss=0.6598, Align=0.6934
✅ Best model saved!



Epoch 5/10: 100%|██████████| 422/422 [10:02<00:00,  1.43s/it, loss=0.5621, align=0.6152, ω1=0.383, ω2=0.230]


Epoch 5/10: Loss=0.4926, Align=0.6351





Val: Loss=0.6342, Align=0.6995
✅ Best model saved!



Epoch 6/10: 100%|██████████| 422/422 [10:02<00:00,  1.43s/it, loss=0.2864, align=0.6481, ω1=0.376, ω2=0.248]


Epoch 6/10: Loss=0.4667, Align=0.6399





Val: Loss=0.6165, Align=0.7035
✅ Best model saved!



Epoch 7/10: 100%|██████████| 422/422 [10:02<00:00,  1.43s/it, loss=0.2453, align=0.6790, ω1=0.357, ω2=0.282]


Epoch 7/10: Loss=0.4405, Align=0.6432





Val: Loss=0.5991, Align=0.7072
✅ Best model saved!



Epoch 8/10: 100%|██████████| 422/422 [10:00<00:00,  1.42s/it, loss=0.2702, align=0.6436, ω1=0.381, ω2=0.238]


Epoch 8/10: Loss=0.4327, Align=0.6458





Val: Loss=0.5884, Align=0.7090
✅ Best model saved!



Epoch 9/10: 100%|██████████| 422/422 [10:01<00:00,  1.43s/it, loss=0.4445, align=0.6404, ω1=0.377, ω2=0.239]


Epoch 9/10: Loss=0.4192, Align=0.6467





Val: Loss=0.5830, Align=0.7094
✅ Best model saved!



Epoch 10/10: 100%|██████████| 422/422 [10:04<00:00,  1.43s/it, loss=0.2211, align=0.6627, ω1=0.371, ω2=0.251]


Epoch 10/10: Loss=0.4161, Align=0.6469





Val: Loss=0.5809, Align=0.7096
✅ Best model saved!


✅ Training completed!


In [7]:
hi

NameError: name 'hi' is not defined

In [8]:
!mkdir module_3_spf

In [None]:
!mv spf_final_corrected.pth spf_final.pth

In [9]:
!cp spf_best.pth module_3_spf/
!cp spf_final.pth module_3_spf/

In [10]:
import os
import json

# Step 1: Create the kaggle.json content
kaggle_token = {
    "username": "phanichaitanya349",
    "key": "91efb90726e30e3dd48368fddb50908a"
}

# Step 2: Save it to ~/.kaggle/kaggle.json
os.makedirs("/root/.kaggle", exist_ok=True)
with open("/root/.kaggle/kaggle.json", "w") as f:
    json.dump(kaggle_token, f)

# Step 3: Set proper permissions
os.chmod("/root/.kaggle/kaggle.json", 0o600)

print("✅ kaggle.json created and configured.")

✅ kaggle.json created and configured.


In [11]:
!kaggle datasets init -p module_3_spf

Data package template written to: module_3_spf/dataset-metadata.json


In [12]:
import json

metadata_path = "module_3_spf/dataset-metadata.json"

# Load existing metadata
with open(metadata_path, "r") as f:
    metadata = json.load(f)

# Set your dataset details
metadata["title"] = "Module 3 SPF 69-10e"
metadata["id"] = "phanichaitanya349/module-3-spf-69-10e"  # Must be lowercase with hyphens
metadata["licenses"] = [{"name": "CC0-1.0"}]  # Open license

# Save the modified metadata
with open(metadata_path, "w") as f:
    json.dump(metadata, f, indent=4)

print("✅ Metadata updated successfully.")

✅ Metadata updated successfully.


In [13]:
!kaggle datasets create -p module_3_spf --dir-mode zip

Starting upload for file spf_best.pth
100%|██████████████████████████████████████| 17.4M/17.4M [00:00<00:00, 29.2MB/s]
Upload successful: spf_best.pth (17MB)
Starting upload for file spf_final.pth
100%|██████████████████████████████████████| 5.81M/5.81M [00:00<00:00, 12.4MB/s]
Upload successful: spf_final.pth (6MB)
Your private Dataset is being created. Please check progress at https://www.kaggle.com/datasets/phanichaitanya349/module-3-spf-69-10e
