In [13]:
# from google.colab import drive
# drive.mount('/content/drive/')
import gdown

gdown.download(id="10pHm5N-PFkTZqJ37MfcNpm7iV4Rjt7zz", output="merged_dataset.zip", quiet=False)

ModuleNotFoundError: No module named 'gdown'

In [23]:
import zipfile
with zipfile.ZipFile('merged_dataset.zip', 'r') as zip_ref:
    zip_ref.extractall()

In [1]:
import torch
torch.cuda.set_per_process_memory_fraction(0.4)
torch.cuda.empty_cache()

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

DATASET_PATH = 'merged_dataset'

all_classes = sorted(os.listdir(DATASET_PATH))

train_classes, test_classes = train_test_split(all_classes, test_size=0.1, random_state=42)

print(f'train classes: {len(train_classes)}, test classes: {len(test_classes)}')

class LogoDataset(Dataset):
    def __init__(self, class_list, transform=None):
        self.class_list = class_list
        self.transform = transform
        self.data = {}

        for class_id in class_list:
            class_path = os.path.join(DATASET_PATH, class_id)
            images = [os.path.join(class_path, img) for img in os.listdir(class_path)]
            if len(images) > 1: 
                self.data[class_id] = images
        
        self.class_ids = list(self.data.keys())

    def __len__(self):
        return sum(len(imgs) for imgs in self.data.values())

    def __getitem__(self, index):
        anchor_class = random.choice(self.class_ids)
        anchor_img = random.choice(self.data[anchor_class])
        positive_img = random.choice(self.data[anchor_class])

        negative_class = random.choice([c for c in self.class_ids if c != anchor_class])
        negative_img = random.choice(self.data[negative_class])

        anchor_img = Image.open(anchor_img).convert("RGB")
        positive_img = Image.open(positive_img).convert("RGB")
        negative_img = Image.open(negative_img).convert("RGB")

        if self.transform:
            anchor_img = self.transform(anchor_img)
            positive_img = self.transform(positive_img)
            negative_img = self.transform(negative_img)

        return anchor_img, positive_img, negative_img

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomRotation(10),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.RandomPerspective(distortion_scale=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = LogoDataset(train_classes, train_transform)
test_dataset = LogoDataset(test_classes, val_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

print(f'train: {len(train_loader.dataset)}, test: {len(test_loader.dataset)}')

class LogoEncoder(nn.Module):
    def __init__(self, embedding_dim=1024):
        super(LogoEncoder, self).__init__()
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        self.model.fc = nn.Linear(self.model.fc.in_features, embedding_dim)
        self.normalize = nn.functional.normalize

    def forward(self, x):
        x = self.model(x)
        x = F.normalize(x, p=2, dim=1)
        return x

class TripletLossCosine(nn.Module):
    def __init__(self, margin=0.2):
        super(TripletLossCosine, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_sim = F.cosine_similarity(anchor, positive)
        neg_sim = F.cosine_similarity(anchor, negative)
        loss = F.relu(self.margin - pos_sim + neg_sim)
        return loss.mean()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = LogoEncoder(embedding_dim=1024).to(device)
criterion = TripletLossCosine(margin=0.4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_val_loss = float('inf')
num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    val_loss = 0
    for anchor, positive, negative in tqdm(train_loader):
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

        anchor_embed = model(anchor)
        positive_embed = model(positive)
        negative_embed = model(negative)

        loss = criterion(anchor_embed, positive_embed, negative_embed)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    train_loss /= len(train_loader)
    with torch.no_grad():
        for anchor, positive, negative in tqdm(test_loader):
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
    
            anchor_embed = model(anchor)
            positive_embed = model(positive)
            negative_embed = model(negative)
    
            loss = criterion(anchor_embed, positive_embed, negative_embed)    
            val_loss += loss.item()
    val_loss /= len(test_loader)
        
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        
    print(f'epoch {epoch+1}/{num_epochs}, train loss: {train_loss:.4f} val loss: {val_loss:.4f}')

Train classes: 2700, Test classes: 300
Train: 175403, Val: 20530, Test: 18860


100%|██████████| 5482/5482 [21:42<00:00,  4.21it/s]
100%|██████████| 590/590 [00:53<00:00, 11.12it/s]


epoch 1/15, train loss: 0.1852 val loss: 0.1489


100%|██████████| 5482/5482 [21:37<00:00,  4.23it/s]
100%|██████████| 590/590 [00:52<00:00, 11.15it/s]


epoch 2/15, train loss: 0.1226 val loss: 0.1326


100%|██████████| 5482/5482 [21:37<00:00,  4.23it/s]
100%|██████████| 590/590 [00:53<00:00, 11.03it/s]


epoch 3/15, train loss: 0.0952 val loss: 0.1240


  9%|▉         | 493/5482 [01:57<20:10,  4.12it/s]