In [None]:
%%shell
# Fail the cell on the first error instead of hiding e.g. a failed download
# behind the exit status of the last command.
set -e

# Download and extract Market-1501 only once so the cell can be re-run safely:
# a second `mv` onto an existing directory would nest the folder inside it,
# and `unzip` would otherwise stop to ask before overwriting files.
if [ ! -d /content/Market-1501 ]; then
  gdown "https://drive.google.com/uc?id=0B8-rUzbwVRk0c054eEozWG9COHM"
  unzip -qq -o Market-1501-v15.09.15.zip
  mv /content/Market-1501-v15.09.15 /content/Market-1501
fi

pip install wandb faiss-gpu
wandb login

In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
import numpy as np
from PIL import Image
import os
from pathlib import Path
from tqdm import tqdm
import wandb

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Must match the directory the setup shell cell moves the extracted dataset to.
MARKET_DATA_DIR = '/content/Market-1501/'

# Triplet sampling in the dataset uses np.random and training uses torch RNGs;
# seed both so a Restart & Run All is reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class Market1501Dataset(Dataset):
    """Triplet dataset over the ``.jpg`` files of one Market-1501 split directory.

    Each item is ``((anchor, positive, negative), anchor_label)``:

    * ``anchor`` is the image at ``idx`` (files are sorted by name);
    * ``positive`` is a random *other* image of the same identity, or the anchor
      itself when that identity has only one image in the directory;
    * ``negative`` is a random image of a random different identity, never
      drawn from labels -1 or 0.

    The integer identity label is the filename prefix before the first ``'_'``
    (a file named ``'0002_....jpg'`` gets label 2).

    Args:
        root_dir: directory containing the images.
        transform: optional callable applied to each loaded PIL image.
    """

    # Identities never sampled as negatives (presumably Market-1501's
    # junk / distractor ids — TODO confirm against the dataset docs).
    _EXCLUDED_NEGATIVE_LABELS = (-1, 0)

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = sorted(f for f in os.listdir(root_dir) if f.endswith('.jpg'))

        # identity label -> list of filenames, in sorted-filename order
        self.label_to_images = {}
        for image_file in self.image_files:
            self.label_to_images.setdefault(self._parse_label(image_file), []).append(image_file)

        # Candidate negative identities, computed once instead of on every __getitem__.
        self._negative_candidates = [
            label for label in self.label_to_images
            if label not in self._EXCLUDED_NEGATIVE_LABELS
        ]

    @staticmethod
    def _parse_label(image_file):
        """Return the integer identity encoded before the first '_' of a filename."""
        return int(image_file.split('_')[0])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        anchor_image_file = self.image_files[idx]
        anchor_label = self._parse_label(anchor_image_file)

        # Prefer a different image of the same identity; fall back to the anchor
        # itself for single-image identities (np.random.choice([]) would raise).
        positive_pool = [img for img in self.label_to_images[anchor_label] if img != anchor_image_file]
        positive_image_file = np.random.choice(positive_pool) if positive_pool else anchor_image_file

        negative_pool = [label for label in self._negative_candidates if label != anchor_label]
        if not negative_pool:
            raise ValueError(
                f'no negative identity available for label {anchor_label} in {self.root_dir}')
        negative_label = np.random.choice(negative_pool)
        negative_image_file = np.random.choice(self.label_to_images[negative_label])

        anchor_img = self.load_image(anchor_image_file)
        positive_img = self.load_image(positive_image_file)
        negative_img = self.load_image(negative_image_file)

        return (anchor_img, positive_img, negative_img), anchor_label

    def load_image(self, image_file):
        """Load ``image_file`` from ``root_dir`` as RGB and apply ``transform`` if set."""
        img_name = os.path.join(self.root_dir, image_file)
        # The context manager closes the file handle; convert() loads the pixels
        # and guarantees 3 channels before the file is closed.
        with Image.open(img_name) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image


# Resize every crop to (height, width) = (128, 64) and convert it to a tensor.
# NOTE(review): no transforms.Normalize is applied even though the ResNet-50
# backbone below loads pretrained weights — consider adding the matching
# normalization (and undoing it before display).
transform = transforms.Compose([transforms.Resize((128, 64)), transforms.ToTensor()])

# One triplet dataset per Market-1501 split directory, all sharing the transform.
train_dataset, test_dataset, query_dataset = (
    Market1501Dataset(os.path.join(MARKET_DATA_DIR, split_dir), transform=transform)
    for split_dir in ('bounding_box_train', 'bounding_box_test', 'query')
)

BATCH_SIZE = 128

# Only the training loader shuffles; evaluation splits keep sorted-filename order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader, query_loader = (
    DataLoader(split_dataset, batch_size=BATCH_SIZE, shuffle=False)
    for split_dataset in (test_dataset, query_dataset)
)

In [None]:
import matplotlib.pyplot as plt

def imshow(img, title=None):
    """Display a channels-first image tensor with matplotlib.

    Args:
        img: torch tensor laid out as (channels, height, width), e.g. the output
            of ``torchvision.utils.make_grid``.
        title: optional title drawn above the image.
    """
    # detach()/cpu() so tensors that require grad or live on the GPU can be shown too.
    array = img.detach().cpu().numpy()
    # matplotlib expects (height, width, channels).
    array = np.transpose(array, (1, 2, 0))

    fig, ax = plt.subplots()
    ax.imshow(array)
    ax.axis('off')
    if title is not None:
        ax.set_title(title)
    plt.show()

# Sanity check: show the first training triplet side by side
# (anchor | positive | negative) together with the anchor's identity.
(anchor_img, positive_img, negative_img), anchor_label = train_dataset[0]
print(f"Anchor label: {anchor_label}")
triplet_grid = torchvision.utils.make_grid([anchor_img, positive_img, negative_img])
imshow(triplet_grid)

In [None]:
# Inspect the last (sorted-by-name) image of the bounding_box_test split and its identity.
last_index = len(test_dataset) - 1
(anchor_img, positive_img, negative_img), anchor_label = test_dataset[last_index]
print(f"Identity: {anchor_label}")
imshow(torchvision.utils.make_grid([anchor_img]))

In [None]:
# Directory that train() writes its checkpoints into.
path_save_model = 'model'
os.makedirs(path_save_model, exist_ok=True)

num_epochs = 30
embedding_size = 128

# Pretrained ResNet-50 whose final fully connected layer is replaced by a linear
# projection to `embedding_size` dimensions.
model = models.resnet50(weights='DEFAULT')
model.fc = nn.Linear(model.fc.in_features, embedding_size)
model = model.to(device)

# Triplet loss on (L2-normalized) embeddings with a small margin.
criterion = nn.TripletMarginLoss(margin=0.02)

# Adam's own lr is overridden by OneCycleLR, which peaks at max_lr and is
# stepped once per training batch.
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
)

def train(model, train_loader, criterion, optimizer, num_epochs):
    """Train ``model`` as a triplet embedding network, checkpointing every epoch.

    Each batch from ``train_loader`` is ``((anchor, positive, negative), labels)``.
    The three image batches are embedded separately, L2-normalized, and passed to
    ``criterion(anchor, positive, negative)``.

    NOTE: besides its arguments this function uses the module-level ``device``,
    ``scheduler`` (stepped once per batch) and ``path_save_model``, and logs to
    the active wandb run.

    After every epoch the whole model object is saved to ``model_latest.pth``.
    It is also saved to ``model_best_loss.pth`` whenever the epoch's mean loss
    is the lowest so far.

    Args:
        model: network mapping an image batch to embeddings.
        train_loader: DataLoader yielding triplet batches as described above.
        criterion: triplet loss over anchor/positive/negative embeddings.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        list[float]: the mean training loss of each epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError('train_loader is empty')

    model.train()

    loss_all = []
    for epoch in range(num_epochs):
        loss_total = 0.0
        # Iterating the loader directly (not enumerate(...)) lets tqdm read its
        # length and draw a real progress bar.
        for triplets, _labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
            anchor, positive, negative = (img.to(device) for img in triplets)

            optimizer.zero_grad()

            # Embed each image batch, then scale every embedding to unit L2 norm.
            anchor_embeddings, positive_embeddings, negative_embeddings = (
                nn.functional.normalize(model(batch), p=2, dim=1)
                for batch in (anchor, positive, negative)
            )

            loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)

            loss.backward()
            optimizer.step()
            scheduler.step()

            # .item() synchronizes with the device; read the value once per step.
            step_loss = loss.item()
            loss_total += step_loss
            wandb.log({
                "loss": step_loss,
                "lr": optimizer.param_groups[0]['lr'],
            })

        # Mean loss over the epoch's batches (kept under the historical name/key "loss_batch").
        loss_batch = loss_total / num_batches
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss_batch}')
        wandb.log({"loss_batch": loss_batch})
        loss_all.append(loss_batch)

        torch.save(model, os.path.join(path_save_model, 'model_latest.pth'))
        if loss_batch == min(loss_all):
            torch.save(model, os.path.join(path_save_model, 'model_best_loss.pth'))

    return loss_all

# Record the main hyperparameters with the run so it can be reproduced from the wandb UI.
run = wandb.init(
    project="human-recognition",
    config={
        "embedding_size": embedding_size,
        "num_epochs": num_epochs,
        "batch_size": train_loader.batch_size,
        "triplet_margin": criterion.margin,
    },
)
try:
    train(model, train_loader, criterion, optimizer, num_epochs)
finally:
    # In a notebook a wandb run is not closed automatically; finish it explicitly,
    # even when training is interrupted.
    run.finish()