In [1]:
# Fashion Recognition Model (Shape, Fabric, Color)

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt


In [2]:
# 1. Constants
IMG_DIR = '../DeepFashionData/images'
SHAPE_PATH = '../DeepFashionData/labels/shape/shape_anno_all.txt'
FABRIC_PATH = '../DeepFashionData/labels/texture/fabric_ann.txt'
PATTERN_PATH = '../DeepFashionData/labels/texture/pattern_ann.txt'
SEGM_DIR = '../DeepFashionData/segm'

SHAPE_ATTR_CLASSES = [6, 5, 4, 3, 5, 3, 3, 3, 5, 7, 3, 3]
FABRIC_CLASSES = 8
PATTERN_CLASSES = 8

In [3]:
# 2. Label Loading Functions
def load_shape_labels(path):
    labels = {}
    with open(path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            name = parts[0]
            attrs = list(map(int, parts[1:]))
            labels[name] = attrs
    return labels

def load_triplet_labels(path):
    labels = {}
    with open(path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            name = parts[0]
            labels[name] = list(map(int, parts[1:]))
    return labels

In [4]:
# 3. Custom Dataset
class FashionDataset(Dataset):
    def __init__(self, img_dir, segm_dir, shape_labels, fabric_labels, pattern_labels, transform=None):
        self.img_dir = img_dir
        self.segm_dir = segm_dir
        self.shape_labels = shape_labels
        self.fabric_labels = fabric_labels
        self.pattern_labels = pattern_labels
        self.transform = transform

        self.filenames = [f for f in shape_labels.keys() if f in fabric_labels and f in pattern_labels
                           and os.path.exists(os.path.join(segm_dir, f.replace('.jpg', '_segm.png')))]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        name = self.filenames[idx]
        img_path = os.path.join(self.img_dir, name)
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        shape = torch.tensor(self.shape_labels[name], dtype=torch.long)
        fabric = torch.tensor(self.fabric_labels[name], dtype=torch.long)
        pattern = torch.tensor(self.pattern_labels[name], dtype=torch.long)

        return image, {'shape': shape, 'fabric': fabric, 'pattern': pattern, 'filename': name}


In [5]:
# 4. Model
class FashionRecognitionModel(nn.Module):
    def __init__(self):
        super().__init__()
        base = models.resnet50(pretrained=True)
        self.backbone = nn.Sequential(*list(base.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()

        self.shape_heads = nn.ModuleList([nn.Linear(2048, c) for c in SHAPE_ATTR_CLASSES])
        self.fabric_heads = nn.ModuleList([nn.Linear(2048, FABRIC_CLASSES) for _ in range(3)])
        self.pattern_heads = nn.ModuleList([nn.Linear(2048, PATTERN_CLASSES) for _ in range(3)])

    def forward(self, x):
        x = self.backbone(x)
        x = self.pool(x)
        x = self.flat(x)
        shape_out = [head(x) for head in self.shape_heads]
        fabric_out = [head(x) for head in self.fabric_heads]
        pattern_out = [head(x) for head in self.pattern_heads]
        return {'shape': shape_out, 'fabric': fabric_out, 'pattern': pattern_out}


In [6]:
# 5. Loss Function
def compute_loss(preds, targets):
    loss_fn = nn.CrossEntropyLoss()
    shape_loss = sum(loss_fn(p, t) for p, t in zip(preds['shape'], targets['shape'].T))
    fabric_loss = sum(loss_fn(p, t) for p, t in zip(preds['fabric'], targets['fabric'].T))
    pattern_loss = sum(loss_fn(p, t) for p, t in zip(preds['pattern'], targets['pattern'].T))
    return shape_loss + fabric_loss + pattern_loss


In [7]:
# 6. Training Loop
def train_model(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    for batch_idx, (images, labels) in enumerate(dataloader):
        images = images.to(device)
        labels = {k: v.to(device) for k, v in labels.items() if k != 'filename'}
        optimizer.zero_grad()
        outputs = model(images)
        loss = compute_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if (batch_idx + 1) % 10 == 0:
            print(f"Batch {batch_idx + 1}/{len(dataloader)}, Loss: {loss.item():.4f}")

    print(f"Average Training Loss: {total_loss / len(dataloader):.4f}")


In [8]:
# 7. Evaluation Function
def evaluate_model(model, dataloader, device):
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            targets = {k: v.to(device) for k, v in labels.items() if k != 'filename'}
            outputs = model(images)
            for key in ['shape', 'fabric', 'pattern']:
                for pred, target in zip(outputs[key], targets[key].T):
                    _, pred_class = torch.max(pred, dim=1)
                    correct += (pred_class == target).sum().item()
                    total += target.numel()
    print(f"Evaluation Accuracy: {100 * correct / total:.2f}%")


In [9]:
# 8. Save and Load

def save_model(model, path='fashion_model.pth'):
    torch.save(model.state_dict(), path)

def load_model(model, path='fashion_model.pth'):
    model.load_state_dict(torch.load(path))
    model.eval()


In [22]:
# 9. Main
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("Loading labels...")
    shape_labels = load_shape_labels(SHAPE_PATH)
    fabric_labels = load_triplet_labels(FABRIC_PATH)
    color_labels = load_triplet_labels(PATTERN_PATH)

    transform = transforms.Compose([
        transforms.Resize((256, 192)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    print("Creating dataset...")
    full_dataset = FashionDataset(IMG_DIR, SEGM_DIR, shape_labels, fabric_labels, color_labels, transform)
    train_size = int(0.9 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

    print(f"Total samples: {len(full_dataset)}, Train: {len(train_dataset)}, Test: {len(test_dataset)}")

    model = FashionRecognitionModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(5):
        print(f"\nEpoch {epoch+1}/5 started")
        train_model(model, train_loader, optimizer, device)
        print(f"Epoch {epoch+1} training complete. Starting evaluation...")
        evaluate_model(model, test_loader, device)
        print(f"Epoch {epoch+1} evaluation complete.")

    print("Saving model...")
    save_model(model)
    print("Model saved.")


Loading labels...
Creating dataset...
Total samples: 12694, Train: 11424, Test: 1270





Epoch 1/5 started
Batch 10/1428, Loss: 17.7317
Batch 20/1428, Loss: 13.0780
Batch 30/1428, Loss: 16.0402
Batch 40/1428, Loss: 14.9717
Batch 50/1428, Loss: 14.4167
Batch 60/1428, Loss: 14.0497
Batch 70/1428, Loss: 11.1688
Batch 80/1428, Loss: 9.2579
Batch 90/1428, Loss: 13.2333
Batch 100/1428, Loss: 13.2415
Batch 110/1428, Loss: 8.3215
Batch 120/1428, Loss: 12.4942
Batch 130/1428, Loss: 11.0418
Batch 140/1428, Loss: 11.1625
Batch 150/1428, Loss: 12.1272


KeyboardInterrupt: 

In [28]:
def predict_single_image(model, image_path, transform, device):
    model.eval()
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image_tensor)

    # Process predictions
    shape_preds = [torch.argmax(p).item() for p in outputs['shape']]
    fabric_preds = [torch.argmax(p).item() for p in outputs['fabric']]
    color_preds = [torch.argmax(p).item() for p in outputs['pattern']]

    print("Prediction Results for:", image_path)
    print("\n--- Shape Attributes ---")
    for i, pred in enumerate(shape_preds):
        print(f"Attr {i+1}: Class {pred}")

    print("\n--- Fabric (Upper, Lower, Outer) ---")
    for i, pred in enumerate(fabric_preds):
        print(f"{['Upper','Lower','Outer'][i]}: {FABRIC_LABELS_LIST[pred]}")

    print("\n--- Color (Upper, Lower, Outer) ---")
    for i, pred in enumerate(color_preds):
        print(f"{['Upper','Lower','Outer'][i]}: {COLOR_LABELS_LIST[pred]}")


In [32]:
FABRIC_LABELS_LIST = [
    'Cotton', 'Denim', 'Leather', 'Linen', 'Silk', 'Wool', 'Polyester', 'Other'
]

COLOR_LABELS_LIST = [
    'Black', 'White', 'Red', 'Blue', 'Green', 'Yellow', 'Brown', 'Other'
]

predict_single_image(model, 'test/test1.jpg', transform, device)


Prediction Results for: test/test1.jpg

--- Shape Attributes ---
Attr 1: Class 1
Attr 2: Class 3
Attr 3: Class 0
Attr 4: Class 0
Attr 5: Class 0
Attr 6: Class 0
Attr 7: Class 0
Attr 8: Class 1
Attr 9: Class 3
Attr 10: Class 2
Attr 11: Class 1
Attr 12: Class 1

--- Fabric (Upper, Lower, Outer) ---
Upper: Denim
Lower: Cotton
Outer: Other

--- Color (Upper, Lower, Outer) ---
Upper: Blue
Lower: Blue
Outer: Other
