In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import MSELoss
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
import matplotlib.pyplot as plt
import os
import json
import numpy as np
from tqdm import tqdm
import neptune
import importlib
from gallery_detection_models import models


## Parameters

In [7]:
# Central configuration of the notebook; later cells read their settings from
# this dict. The two filesystem locations can be overridden with environment
# variables so the notebook is not tied to one machine's mount points. When
# the variables are unset, the original absolute paths are used unchanged.
PARAMETERS = {
    # Folder that contains the dataset's index.json.
    "dataset_folder_path": os.environ.get(
        "GALLERY_DATASET_DIR",
        "/media/lorenzo/SAM500/datasets/gallery_detection_dataset",
    ),
    # Total number of samples to load; None loads every available sample.
    "n_samples": None,
    # Number of samples per DataLoader batch.
    "batch_size": 64,
    # Number of passes over the dataset.
    "n_epochs": 32,
    # Learning rate handed to the Adam optimizer.
    "lr": 0.00004,
    # Gamma of the ExponentialLR scheduler, stepped once per completed epoch.
    "lr_decay": 0.99,
    # Folder where the trained weights are written.
    "save_folder": os.environ.get(
        "GALLERY_MODELS_DIR",
        "/media/lorenzo/SAM500/models/gallery-detection/",
    ),
}



## Load dataset

In [8]:
class GalleryDetectionDataset(Dataset):
    """Dataset that loads every selected sample into memory at construction.

    The ``index`` dict is read as follows:
      - ``index["info"]["path_to_dataset"]``: root folder of the dataset.
      - ``index["info"]["image_width"]``: width used to allocate the image tensor.
      - ``index["data"][world]["images_folder"]``: folder of one world,
        joined onto the root folder.
      - ``index["data"][world]["n_datapoints"]``: number of samples of that world.

    Every sample is a file named ``<idx as 10 zero-padded digits>.npz`` that
    holds an ``"image"`` and a ``"label"`` array. Images are stored in a
    ``(N, 1, 16, image_width)`` tensor and labels in a ``(N, 360)`` tensor.

    Args:
        index: Parsed content of the dataset index (see above).
        n_desired_samples: Upper bound on the number of samples to load,
            spread over the worlds proportionally to their size. ``None``
            (the default) loads every available sample.
    """

    def __init__(self, index, n_desired_samples=None):
        self.index = index
        self.get_total_datapoints()
        if n_desired_samples is None:
            self.n_desired_samples = self.n_available_samples
        else:
            # Never ask for more samples than the index says exist.
            self.n_desired_samples = min(n_desired_samples, self.n_available_samples)
        self.set_n_samples_per_world()
        self.load()

    def get_total_datapoints(self):
        """Store in ``n_available_samples`` the sample count summed over all worlds."""
        self.n_available_samples = sum(
            world["n_datapoints"] for world in self.index["data"].values()
        )

    def set_n_samples_per_world(self):
        """Split ``n_desired_samples`` over the worlds proportionally to their size.

        Each share is rounded independently, so ``final_n_datapoints`` (the sum
        of the shares) can differ slightly from ``n_desired_samples``.
        """
        self.n_samples_per_world = {}
        for world_name, world in self.index["data"].items():
            if self.n_available_samples == 0:
                # Worlds are listed but none has datapoints: nothing to
                # distribute, and the ratio below would divide by zero.
                n_samples = 0
            else:
                n_samples = int(
                    np.round(
                        world["n_datapoints"]
                        * self.n_desired_samples
                        / self.n_available_samples
                    )
                )
            self.n_samples_per_world[world_name] = n_samples
        self.final_n_datapoints = sum(self.n_samples_per_world.values())
        print(self.n_samples_per_world)

    def load(self):
        """Read the selected samples of every world into ``images`` and ``labels``.

        For each world a random subset (without repetition) of its sample
        indices is drawn with ``np.random`` and the matching ``.npz`` files
        are copied into the pre-allocated tensors. A message is printed for
        every label that contains a NaN.
        """
        print("Allocating memory")
        info = self.index["info"]
        self.images = torch.zeros(
            (self.final_n_datapoints, 1, 16, info["image_width"])
        )
        self.labels = torch.zeros((self.final_n_datapoints, 360))
        global_index = 0
        with tqdm(total=self.final_n_datapoints) as pbar:
            for world_name, world in self.index["data"].items():
                path_to_world_folder = os.path.join(
                    info["path_to_dataset"], world["images_folder"]
                )
                assert os.path.exists(
                    path_to_world_folder
                ), f"World folder not found: {path_to_world_folder}"
                # Shuffle all indices of the world and keep the first ones so
                # that the subset is random and has no duplicates.
                raw_idxs = np.arange(0, world["n_datapoints"])
                np.random.shuffle(raw_idxs)
                for idx in raw_idxs[: self.n_samples_per_world[world_name]]:
                    path_to_file = os.path.join(
                        path_to_world_folder, f"{idx:010d}.npz"
                    )
                    # np.load returns an NpzFile for .npz archives, which keeps
                    # the file open until closed; the context manager releases
                    # the descriptor as soon as the arrays are copied.
                    with np.load(path_to_file) as data:
                        self.images[global_index, 0, :, :] = torch.tensor(
                            data["image"]
                        )
                        self.labels[global_index] = torch.tensor(data["label"])
                    if torch.any(torch.isnan(self.labels[global_index])):
                        print("There is a NaN in the raw label")
                    global_index += 1
                    pbar.update(1)
        print("Dataset loaded")

    def __len__(self):
        """Return the number of samples held in memory."""
        return self.final_n_datapoints

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at position ``index``."""
        img, lbl = self.images[index], self.labels[index]
        if torch.any(torch.isnan(lbl)):
            print("Nan in __getitem__")
        return img, lbl

In [9]:
# Dataset-related settings come from the shared configuration dict.
dataset_folder_path = PARAMETERS["dataset_folder_path"]
n_samples = PARAMETERS["n_samples"]
batch_size = PARAMETERS["batch_size"]

# The index file is parsed as JSON and handed to the dataset class as-is.
path_to_index_file = os.path.join(dataset_folder_path, "index.json")
assert os.path.exists(path_to_index_file)
with open(path_to_index_file, "r") as f:
    index = json.load(f)

# Constructing the dataset loads the samples into memory; the loader then
# serves them in shuffled batches.
dataset = GalleryDetectionDataset(index, n_desired_samples=n_samples)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

{'env_018': 21777, 'env_007': 13324, 'env_010': 21551, 'env_020': 48033, 'env_017': 21721, 'env_002': 19991, 'env_015': 18613, 'env_008': 27916, 'env_003': 23698, 'env_016': 27254, 'env_005': 31794, 'env_013': 18436, 'env_019': 36708, 'env_004': 28563, 'env_011': 17309, 'env_014': 27147, 'env_006': 34637, 'env_009': 21455, 'env_012': 38653, 'env_001': 24794}
Allocating memory


100%|██████████| 523374/523374 [17:18<00:00, 504.09it/s]

Dataset loaded





## Data Augmentation

In [None]:
from torchvision.transforms import Compose, RandomErasing
def augment(inpt: torch.Tensor):
    """Placeholder for the data augmentation of a batch of images.

    Args:
        inpt: Tensor that should be of the shape B x C x H x W, where:
            - B: batch size
            - C: number of channels
            - H: image height
            - W: image width

    NOTE(review): no augmentation is implemented yet. The body consists of
    this docstring only, so calling the function returns None.
    """
    


## Train

In [10]:
print("Getting parameters")
# Pull the training settings out of the configuration dict in one unpacking.
n_epochs, lr, lr_decay, save_folder = (
    PARAMETERS[key] for key in ("n_epochs", "lr", "lr_decay", "save_folder")
)
# Create the output folder up front; an already existing folder is fine.
os.makedirs(save_folder, exist_ok=True)

Getting parameters


In [29]:
# Re-import the local model definitions so that edits made to the module are
# picked up without restarting the kernel.
importlib.reload(models)
print("Starting neptune run")
# The API token is a credential and must not be stored in the notebook: it is
# read from the NEPTUNE_API_TOKEN environment variable instead.
# NOTE(review): the token that used to be hard-coded here has been exposed in
# the notebook source and should be revoked.
run = neptune.init_run(
    project="lcano/gallery-detection",
    api_token=os.environ.get("NEPTUNE_API_TOKEN"),
    capture_stderr=False,
    capture_stdout=False,
)
try:
    model = models.GalleryDetectorV3(debug=False)
    model = model.type(torch.float)
    model = model.to("cuda")
    optimizer = Adam(model.parameters(), lr=lr)
    lr_scheduler = ExponentialLR(optimizer, lr_decay)
    criterion = MSELoss(reduction="mean")
    save_file_path = os.path.join(
        save_folder, f"{model.__class__.__name__}.v3.torch"
    )
    print(f"Saving at: {save_file_path}")
    with tqdm(total=n_epochs * len(dataloader)) as pbar:
        for n_epoch in range(n_epochs):
            epoch_avg_loss = 0
            n_batches = 0
            # Set to the message to print when a sanity check fails; a failed
            # check aborts the whole training, not just the current epoch.
            abort_message = None
            for img, lbl in dataloader:
                n_batches += 1
                if torch.any(torch.isnan(lbl)):
                    abort_message = "Break for NaN in label pre-cuda"
                    break
                img = img.to("cuda").type(torch.float)
                lbl = lbl.to("cuda").type(torch.float)
                optimizer.zero_grad()
                pred = model(img)
                if torch.max(pred) == 0:
                    abort_message = "Break for collapse"
                elif torch.any(torch.isnan(img)):
                    abort_message = "Break for NaN in image"
                elif torch.any(torch.isnan(lbl)):
                    abort_message = "Break for NaN in label"
                elif torch.any(torch.isnan(pred)):
                    abort_message = "Break for NaN in prediction"
                if abort_message is not None:
                    break
                # MSELoss expects (input, target): prediction first, label second.
                loss = criterion(pred, lbl)
                loss.backward()
                optimizer.step()
                epoch_avg_loss += loss.item()
                run["train/loss"].append(loss.item())
                pbar.update(1)
            if abort_message is not None:
                print(abort_message)
                break

            # The epoch completed: decay the learning rate and average the loss.
            lr_scheduler.step()
            epoch_avg_loss /= n_batches
            # Weights are only written once the mean epoch loss is below 0.03.
            if epoch_avg_loss < 0.03:
                # The model is moved to the CPU for saving and back afterwards.
                torch.save(model.to("cpu").state_dict(), save_file_path)
                model.to("cuda")
            # Log the first sample of the last batch: the input image in one
            # axes, prediction and label curves in a second axes above it.
            fig = plt.figure()
            axes1 = fig.add_axes([0, 0, 1, 1])
            axes1.imshow(torch.clone(img[0][0]).detach().cpu().numpy())
            axes2 = fig.add_axes([0, 1, 1, 1])
            axes2.plot(torch.clone(pred[0].detach().cpu()).numpy())
            axes2.plot(torch.clone(lbl[0].detach().cpu()).numpy())
            run["predictions"].append(fig)
            # Close this specific figure so figures do not pile up over epochs.
            plt.close(fig)
finally:
    # Always end the Neptune run, also when training raises.
    run.stop()

Starting neptune run
https://app.neptune.ai/lcano/gallery-detection/e/GAL-117
Saving at: /media/lorenzo/SAM500/models/gallery-detection/GalleryDetectorV3.v3.torch


  0%|          | 3156/2093568 [03:01<33:21:37, 17.41it/s]