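# Pretraining test environment for the AMI trap classifier

This notebook mounts the GBIF pretraining data from ERDA and builds train/validation/test dataloaders. It sanity-checks them by plotting a sample batch before the model is defined and trained.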

## Setup

### Set the working directory to the project root (so the `utils` package can be imported)

In [1]:
import os

# The notebook lives one level below the project root; move up so that the
# `utils` package imported in the next cell is found. The root is computed only
# on the first run and reused afterwards, so re-running this cell does not keep
# climbing the directory tree (a bare os.chdir('..') is not idempotent).
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath('..')
os.chdir(PROJECT_ROOT)

### Import modules

In [2]:
import utils.config as c
import utils.mount as m
import utils.dataloader as dl

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

### Mount ERDA

In [3]:
# Mount the ERDA pretraining data. In the recorded run this printed the rclone
# command it used and returned True; the return value is shown as cell output.
mount_ok = m.mount()
mount_ok

Mounting remote directory with command: rclone mount erda:/AMI_GBIF_Pretraining_Data/root erda-home --vfs-cache-mode full --vfs-read-chunk-size 1M --vfs-cache-max-age 15m --vfs-cache-max-size 50G --max-read-ahead 1M --dir-cache-time 15m --daemon
Mounted remote directory successfully


True

### Unmount ERDA (run this manually only once you are finished with the data)

In [None]:
# Guarded so that "Restart & Run All" does not unmount ERDA right after the
# mount cell and before the dataloader cells below read from it.
# Set this to True and run the cell manually once you are done with the data.
UNMOUNT_ERDA = False
if UNMOUNT_ERDA:
    m.unmount()

## Define the datasets and dataloaders

In [16]:
# NOTE(review): this cell was executed out of order (In [16]) and its result is
# not used by any later cell. It presumably is a quick check that
# PretrainingImages can be built from dl.root_dir. Keep or remove deliberately.
pretraining_images = dl.PretrainingImages(dl.root_dir)
pretraining_images

In [4]:
# Loader settings come from the project's dataloader config.
batch_size, numworkers, pin_memory = dl.config['batch_size'], dl.config['num_workers'], dl.config['pin_memory']

# Fixed seed so the train/val/test split is identical across kernel restarts.
SPLIT_SEED = 42
# Fractions of the dataset used for training and validation; the rest (~10%) is test.
TRAIN_FRAC, VAL_FRAC = 0.85, 0.05

# NOTE(review): the recorded run failed on this cell with
# "AttributeError: 'PretrainingDataset' object has no attribute 'images'".
# That attribute is presumably accessed inside utils.dataloader; fix it there.
dataset = dl.PretrainingDataset()

# Truncating each fraction with int() independently can leave the lengths
# summing to less than len(dataset), which random_split rejects. The test split
# takes the remainder so the three lengths always add up exactly.
n_total = len(dataset)
n_train = int(n_total * TRAIN_FRAC)
n_val = int(n_total * VAL_FRAC)
n_test = n_total - n_train - n_val

train_data, val_data, test_data = torch.utils.data.random_split(
    dataset,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

loader_kwargs = dict(batch_size=batch_size, num_workers=numworkers, pin_memory=pin_memory)
# Only training needs shuffling; a fixed order keeps val/test passes deterministic.
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_kwargs)
val_loader   = torch.utils.data.DataLoader(val_data, shuffle=False, **loader_kwargs)
test_loader  = torch.utils.data.DataLoader(test_data, shuffle=False, **loader_kwargs)

AttributeError: 'PretrainingDataset' object has no attribute 'images'

### Test the dataloader by plotting a random batch of training images

In [None]:
# Report each split's size; len(loader) counts batches, len(loader.dataset) counts images.
for split_name, loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{split_name} size: {len(loader.dataset)} images in {len(loader)} batches")

# Get one random training batch. DataLoader iterators are advanced with the
# builtin next(); the .next() method alias was removed in newer PyTorch releases.
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Show the batch as one image grid: make_grid returns a C x H x W tensor, while
# imshow expects H x W x C.
# NOTE(review): assumes the dataset yields un-normalized image tensors in [0, 1];
# normalized inputs would be clipped by imshow. TODO confirm.
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(images).permute(1, 2, 0))
ax.set_title("Random batch of training images")
ax.axis("off");

## Define model

### Define optimizer, loss function and other hyperparameters

## Train