# StyleSync: Multi-Label Fashion Classification (Streaming from S3)

This notebook trains a ResNet-50 multi-label classifier on the Fashion Product Images Dataset, predicting each product's article type, base colour, season, and usage.
It uses `s3torchconnector` to stream images directly from S3 instead of downloading the dataset to local disk.

In [None]:
# Cell 1: Install dependencies
# IMPORTANT: After running this cell, you MUST restart the runtime (Runtime -> Restart Runtime)
# for the installation to take effect.
!pip install s3torchconnector "smart_open[s3]" pandas matplotlib tqdm boto3

In [None]:
# Cell 2: Imports & Auth
import os
import io
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import boto3
from PIL import Image
from tqdm.auto import tqdm
from google.colab import userdata
from sklearn.preprocessing import MultiLabelBinarizer

# AWS S3 Connector
from s3torchconnector import S3MapDataset

# Torchvision
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

# Smart Open for S3 file handling
from smart_open import open as s_open

# Setup Credentials
# Copy the AWS key pair from Colab secrets into the process environment so the
# S3 clients created later in the notebook can pick them up.
# .strip() guards against stray whitespace introduced when pasting secrets.
for _secret_name in ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'):
    os.environ[_secret_name] = userdata.get(_secret_name).strip()
del _secret_name
os.environ['AWS_REGION'] = 'us-east-1'

# S3 layout of the raw dataset
BUCKET_NAME = "stylesync-mlops-data"
PREFIX = "style-sync/raw/fashion"
_DATA_ROOT = f"s3://{BUCKET_NAME}/{PREFIX}"
STYLES_CSV_KEY = f"{_DATA_ROOT}/styles.csv"   # product metadata CSV
IMAGES_PREFIX = f"{_DATA_ROOT}/images/"       # images live at <id>.jpg under here

print("AWS Credentials loaded and environment configured.")

In [None]:
# Cell 3: Metadata Processing

def load_metadata(csv_s3_uri):
    """Download the styles CSV from S3 and build a multi-label ``tags`` column.

    Rows missing any of ``id``, ``articleType``, ``baseColour``, ``season`` or
    ``usage`` are dropped. Each remaining row gets a ``tags`` list holding those
    four attribute values, in that order.

    Args:
        csv_s3_uri (str): ``s3://`` URI of the styles CSV.

    Returns:
        pd.DataFrame: Cleaned metadata whose ``id`` column is a string image ID
        (integer-formatted, e.g. ``"15970"``) and whose ``tags`` column holds
        one list of attribute strings per row.
    """
    print(f"Downloading metadata from {csv_s3_uri}...")

    # Build the S3 client from the credentials exported to the environment and
    # hand it to smart_open explicitly, so it uses exactly these credentials.
    session = boto3.Session(
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        region_name=os.environ['AWS_REGION'],
    )

    # Stream the CSV straight into pandas; malformed lines are skipped rather
    # than aborting the whole load.
    with s_open(csv_s3_uri, 'rb', transport_params={'client': session.client('s3')}) as f:
        df = pd.read_csv(f, on_bad_lines='skip')

    # Attributes we predict: Article Type, Color, Season, Usage.
    tag_columns = ['articleType', 'baseColour', 'season', 'usage']
    df = df.dropna(subset=['id', *tag_columns])

    # pandas stores an integer column as float64 whenever it held a missing
    # value at parse time. str() of a float gives "15970.0", which would break
    # every "<id>.jpg" image URI built from this column, so cast back to int first.
    if pd.api.types.is_float_dtype(df['id']):
        df['id'] = df['id'].astype('int64')
    df['id'] = df['id'].astype(str)

    # One tag list per row: [articleType, baseColour, season, usage].
    # Selecting the columns in one shot avoids the per-row overhead of
    # DataFrame.apply(axis=1) and also works on an empty frame.
    df['tags'] = df[tag_columns].values.tolist()

    return df

# Load Data: pull styles.csv from S3 and build per-row tag lists
df = load_metadata(STYLES_CSV_KEY)

# Fit the binarizer over every tag seen in the dataset. Each row's 'labels'
# entry becomes one multi-hot numpy row of the resulting matrix.
mlb = MultiLabelBinarizer()
_multi_hot = mlb.fit_transform(df['tags'])
df['labels'] = [hot_row for hot_row in _multi_hot]
del _multi_hot

CLASSES = mlb.classes_
NUM_CLASSES = CLASSES.shape[0]

for _summary_line in (
    f"Total Images: {len(df)}",
    f"Total Classes: {NUM_CLASSES}",
    f"Sample Classes: {CLASSES[:10]}",
    f"Sample Row:\n{df.iloc[0][['id', 'tags']]}",
):
    print(_summary_line)

In [None]:
# Cell 4: Custom S3 Streaming Dataset

class S3FashionDataset(Dataset):
    """Map-style dataset that streams product images from S3 on demand.

    Row ``i`` of ``dataframe`` maps to the object ``f"{bucket_prefix}{id}.jpg"``.
    Each sample is that image decoded to RGB and optionally transformed, paired
    with the row's ``labels`` vector as a float32 tensor.
    """

    def __init__(self, dataframe, bucket_prefix, transform=None, strict=False):
        """
        Args:
            dataframe (pd.DataFrame): DataFrame containing 'id' and 'labels'.
            bucket_prefix (str): S3 prefix for images (e.g., s3://bucket/path/images/).
            transform (callable, optional): Transform to be applied on a sample.
            strict (bool): If True, re-raise fetch/decode/transform errors
                instead of substituting an all-zero sample. Useful while
                debugging, since the fallback otherwise hides systematic failures.
        """
        self.df = dataframe
        self.bucket_prefix = bucket_prefix
        self.transform = transform
        self.strict = strict

        # Position i in this list corresponds to row i of the dataframe,
        # matching the positional .iloc lookup in __getitem__.
        self.image_uris = [f"{bucket_prefix}{row_id}.jpg" for row_id in self.df['id']]

        # S3MapDataset serves the objects by position in image_uris.
        # Pass the region explicitly: some s3torchconnector releases require it.
        self.s3_dataset = S3MapDataset.from_objects(
            self.image_uris, region=os.environ.get('AWS_REGION')
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for row ``idx``.

        An out-of-range ``idx`` raises IndexError, as Python's sequence
        protocol expects.

        If fetching, decoding or transforming the image fails, the error is
        printed and an all-zero (3, 224, 224) image is returned together with
        all-zero labels. Set ``strict=True`` to re-raise the error instead.
        """
        # Labels are pre-computed in the dataframe as multi-hot numpy arrays.
        # Looking the row up outside the try lets IndexError propagate.
        labels = torch.tensor(self.df.iloc[idx]['labels'], dtype=torch.float32)

        try:
            # Items of S3MapDataset are file-like S3Reader objects; read the
            # whole object and decode it with PIL.
            s3_reader = self.s3_dataset[idx]
            image = Image.open(io.BytesIO(s3_reader.read())).convert('RGB')

            if self.transform:
                image = self.transform(image)
        except Exception as e:
            if self.strict:
                raise
            print(f"Error loading index {idx}: {e}")
            # Best-effort fallback so one bad object does not abort the epoch.
            # Not recommended for production: these samples still contribute
            # an (all-negative) loss term.
            return torch.zeros((3, 224, 224)), torch.zeros_like(labels)

        return image, labels

# Define Transforms
# ImageNet channel statistics, matching the ImageNet-pretrained backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # fixed output size; aspect ratio not preserved
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Instantiate Dataset over every row of the cleaned metadata
dataset = S3FashionDataset(df, IMAGES_PREFIX, transform=train_transform)

# Quick Test
print(f"Dataset Size: {len(dataset)}")

In [None]:
# Cell 5: Model Definition

class MultiLabelResNet(nn.Module):
    """ResNet-50 backbone with a ``num_classes``-way linear head.

    The head emits raw logits (no sigmoid), one per class. They are meant to be
    paired with ``nn.BCEWithLogitsLoss`` for multi-label training.
    """

    def __init__(self, num_classes, pretrained=True):
        """
        Args:
            num_classes (int): Number of output logits (one per label).
            pretrained (bool): Load ImageNet weights into the backbone.
                Defaults to True, matching the original behaviour.
        """
        super().__init__()
        # `pretrained=True` is deprecated in torchvision; the weights enum is
        # the supported spelling. IMAGENET1K_V1 is what the legacy flag loaded.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet = models.resnet50(weights=weights)

        # Swap the ImageNet classifier for a multi-label head.
        # ResNet-50's fc layer takes 2048 input features.
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        # assumes x is a normalized (batch, 3, H, W) image batch — see train_transform
        return self.resnet(x)

# Place the model on the GPU when one is available, otherwise on the CPU.
_accelerator = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_accelerator)
model = MultiLabelResNet(num_classes=NUM_CLASSES).to(device)

print(f"Model initialized on {device}")

In [None]:
# Cell 6: Training Loop & Checkpointing

def save_checkpoint(state, filename="checkpoint.pth"):
    """Saves checkpoint directly to S3

    ``state`` is serialized in memory with ``torch.save`` and then uploaded to
    ``s3://{BUCKET_NAME}/style-sync/models/{filename}``. The upload uses an S3
    client built from the AWS credentials and region in ``os.environ``.
    """
    target_uri = f"s3://{BUCKET_NAME}/style-sync/models/{filename}"
    print(f"Saving checkpoint to {target_uri}...")

    # Serialize into an in-memory buffer first, then upload its bytes.
    staging = io.BytesIO()
    torch.save(state, staging)
    payload = staging.getvalue()

    # Hand smart_open an explicit client so the environment credentials are used.
    credentials = {
        'aws_access_key_id': os.environ['AWS_ACCESS_KEY_ID'],
        'aws_secret_access_key': os.environ['AWS_SECRET_ACCESS_KEY'],
        'region_name': os.environ['AWS_REGION'],
    }
    s3_client = boto3.Session(**credentials).client('s3')

    with s_open(target_uri, 'wb', transport_params={'client': s3_client}) as sink:
        sink.write(payload)
    print("Checkpoint saved.")

def train_model(model, dataloader, criterion, optimizer, num_epochs=10, device=None):
    """Train ``model`` for ``num_epochs`` epochs, checkpointing to S3 after each.

    Args:
        model (nn.Module): Network to optimize; switched to train mode here.
        dataloader (DataLoader): Yields ``(images, labels)`` batches.
        criterion (callable): Loss called as ``criterion(outputs, labels)``.
        optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
        num_epochs (int): Number of passes over ``dataloader``.
        device (torch.device, optional): Where each batch is moved. Defaults to
            the device of the model's first parameter, which is where the
            batch has to live for the forward pass anyway.

    Side effects:
        Prints the mean batch loss per epoch. After every epoch it calls
        ``save_checkpoint`` with ``resnet50_epoch_{n}.pth``.
    """
    if device is None:
        device = next(model.parameters()).device

    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        loop = tqdm(dataloader, total=len(dataloader), leave=True)
        loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")

        for images, labels in loop:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() copies the value to the host; do it once per step and reuse it.
            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1
            loop.set_postfix(loss=batch_loss)

        # Average over the batches actually seen; avoids dividing by zero on an
        # empty loader.
        epoch_loss = running_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")

        # Save Checkpoint every epoch
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            # Store a plain list of str rather than a numpy object array, so the
            # file loads with torch.load(weights_only=True), the default since
            # PyTorch 2.6.
            'classes': [str(c) for c in CLASSES],
        }
        save_checkpoint(checkpoint, filename=f"resnet50_epoch_{epoch+1}.pth")

In [None]:
# Cell 7: Execution

# Hyperparameters
BATCH_SIZE = 32        # images per optimizer step
LEARNING_RATE = 1e-4   # Adam step size
EPOCHS = 10            # full passes over the dataset

# Shuffled batches streamed from S3. A small worker pool (2) keeps the number
# of concurrent S3 requests modest; many more workers may cause timeouts.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Multi-label objective (per-class sigmoid + binary cross-entropy), optimized
# with Adam over every model parameter.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Start Training
train_model(
    model,
    dataloader,
    criterion,
    optimizer,
    num_epochs=EPOCHS,
)