In [None]:
import os
import pandas as pd
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F

from cotatenis_sneakers.sneaker_dataset import SneakerDataset
from cotatenis_sneakers.sneaker_transforms import get_transform, UnNormalize

In [None]:
# Notebook configuration.
# _download_: when True, the download cell further down runs the data scripts.
# folder: root directory that the train/ and test/ splits are read from.
_download_ = False
folder = "data/public"

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Report which device was selected in the configuration cell.
print(f"device: {device}")

## 1 - Data download and import

In [None]:
# Optionally fetch and prepare the data. The scripts run in order, and the
# cell stops at the first failure so that the preparation step never runs
# against a missing or partial download.
if _download_:
    # Both scripts are launched through the Python interpreter; the original
    # second command omitted `python`, which only works when the script is
    # executable and carries a shebang line.
    for script in ("download_data.py", "data/prepare_data.py"):
        exit_status = os.system(f"python {script}")
        # os.system does not raise on failure: a non-zero status is the only
        # signal that the command went wrong.
        if exit_status != 0:
            raise RuntimeError(
                f"'python {script}' failed with exit status {exit_status}"
            )

In [None]:
# Build the preprocessing pipeline with the project helper. The same
# `transform` object is handed to both the train and the test dataset.
transform = get_transform()

In [None]:
# Each split lives in its own sub-directory of `folder`, next to a CSV that
# is loaded into a DataFrame and passed to SneakerDataset.
train_dir = f"{folder}/train"
test_dir = f"{folder}/test"

# Training split.
train_data = pd.read_csv(f"{train_dir}/train.csv")
train_dataset = SneakerDataset(
    train_data,
    folder=train_dir,
    device=device,
    transform=transform,
)

# Test split, built with the same device and transform.
test_data = pd.read_csv(f"{test_dir}/test.csv")
test_dataset = SneakerDataset(
    test_data,
    folder=test_dir,
    device=device,
    transform=transform,
)

In [None]:
# Sanity check: display what indexing the training dataset returns for its
# first item.
train_dataset[0]

In [None]:
# Number of samples in each split, displayed as a (train, test) tuple.
len(train_dataset), len(test_dataset)

## 2 - Data Visualization

In [None]:
# Draw a 4x4 grid of randomly picked, untransformed training images, each
# titled with its brand. The draw is unseeded, so the grid changes per run.
fig, ax = plt.subplots(4, 4, figsize=(10, 10))
n_train_images = train_dataset.data.shape[0]

# `ax.flat` walks the grid row by row, the same order as a nested i/j loop.
for panel in ax.flat:
    sample_idx = np.random.randint(n_train_images)
    img, brand = train_dataset.get_untransformed_tuple(sample_idx)
    panel.imshow(img)
    panel.set_title(brand)
    panel.axis("off")

plt.show()

In [None]:
# Label distribution: one pie slice per brand, sized by how often the brand
# occurs in the training labels and annotated with its percentage share.
brands = train_dataset.labels.value_counts()
plt.pie(
    brands,
    labels=brands.index,
    autopct="%1.1f%%",
)
plt.title("Brands distribution")
plt.show()

In [None]:
# Check whether every untransformed training image reports the same `.size`.
# Collecting the distinct sizes loads each image once; the original compared
# every image against index 0 and reloaded that reference image on every
# iteration. An empty dataset still yields True, as before.
unique_sizes = {
    train_dataset.get_untransformed_tuple(i)[0].size
    for i in range(len(train_dataset))
}
print("All images have the same dimensions:", len(unique_sizes) <= 1)

In [None]:
# Bar chart of how often each image size occurs among the untransformed
# training images. Sizes are turned into strings so that they can be used as
# categorical tick labels on the x axis.
size_labels = []
for i in range(len(train_dataset)):
    image = train_dataset.get_untransformed_tuple(i)[0]
    size_labels.append(str(image.size))

sizes = pd.Series(size_labels).value_counts()

plt.bar(sizes.index, sizes.values)
plt.xticks(rotation=90)
plt.title("Images dimensions distribution")
plt.show()

## 3 - Data Preprocessing

Preprocessing is done through the `transform` parameter of the sneaker dataset. You can find the details of each step in `cotatenis_sneakers/sneaker_transforms.py`. We pad the images so that they all have the same size, and we also normalise them.

In [None]:
# Show one test sample side by side at three stages: untouched, after the
# dataset transform, and after the transform followed by UnNormalize.
sample_idx = 752  # arbitrary test sample used for the illustration
fig, axs = plt.subplots(1, 3, figsize=(10, 5))
raw_ax, transformed_ax, denorm_ax = axs

raw_ax.imshow(test_dataset.get_untransformed_tuple(sample_idx)[0])
raw_ax.set_title("Original image")

transformed_img = F.to_pil_image(test_dataset[sample_idx][0])
transformed_ax.imshow(transformed_img)
transformed_ax.set_title("Transformed image")

# The sample is fetched again for the third panel rather than reused.
unnorm = UnNormalize()
denorm_img = F.to_pil_image(unnorm(test_dataset[sample_idx][0]))
denorm_ax.imshow(denorm_img)
denorm_ax.set_title("Transformed image\nwith denormalisation")

## 4 - Prediction

In [None]:
# Build the evaluation DataLoader, with pinned host memory when it can help.
#
# `device` is a torch.device, and comparing it with the string "cuda" is
# never true, so the original `pin_memory=True` branch was dead code. The
# check now uses `device.type`.
use_pinned_memory = False
if device.type == "cuda" and len(test_dataset) > 0:
    # NOTE(review): SneakerDataset is constructed with `device=device`, so it
    # may already return CUDA tensors (its code is not visible here). Only CPU
    # tensors can be pinned, so pinning is enabled only when a sample is not
    # already on the GPU.
    first_image = test_dataset[0][0]
    use_pinned_memory = not getattr(first_image, "is_cuda", False)

test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    pin_memory=use_pinned_memory,
)

In [None]:
# Load a ResNet-50 with ImageNet weights from the torchvision hub and replace
# its final fully connected layer with a fresh 2048 -> 3 linear head.
# NOTE(review): the new head is untrained and no training step is visible in
# this notebook, so the accuracy measured with it is only a baseline.
backbone = torch.hub.load(
    "pytorch/vision",
    "resnet50",
    weights="IMAGENET1K_V2",
)
backbone.fc = torch.nn.Linear(in_features=2048, out_features=3)

# Module.to() returns the module itself, so `model` and `backbone` are the
# same object.
model = backbone.to(device)

# Evaluate the model on the test set. Inference only: the model is switched
# to eval mode and the loop runs without gradient tracking.
model.eval()

# Windowed counters, reset after every progress print.
correct = 0
total = 0
# Global counters, updated on EVERY batch. The original only added to
# `total_correct` inside the progress-print branch, so the batches after the
# last print were never counted and the final accuracy came out too low.
total_correct = 0
total_images = 0

print_every = 10

with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        # Predicted class: index of the highest score in each row.
        predicted = outputs.argmax(dim=1)
        # Labels are reduced with an argmax over dim 1 as well, as in the
        # original code (presumably one-hot rows; confirm in SneakerDataset).
        targets = labels.argmax(dim=1)

        batch_correct = (predicted == targets).sum().item()
        batch_size = labels.size(0)

        correct += batch_correct
        total += batch_size
        total_correct += batch_correct
        total_images += batch_size

        if i % print_every == 0 and i != 0:
            # Accuracy over the batches seen since the previous print.
            print(f"Iteration {i}, accuracy: {correct / total}")
            correct = 0
            total = 0


print("Total correct", total_correct)
# The image count is the number of samples actually evaluated.
print("Total images", total_images)
# With an empty loader the accuracy is undefined: report NaN instead of
# raising ZeroDivisionError.
accuracy = total_correct / total_images if total_images else float("nan")
print(f"Test Accuracy: {accuracy}")