In [None]:
# !/usr/bin/env python
# coding: utf-8

# ---- Library import ----

import math
import os
import pickle
import warnings
from time import gmtime, strftime

import albumentations
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# ---- My utils ----
from utils.train_arguments import *
from utils.utils_data import *
from utils.utils_training import *

### Data

In [None]:
# First, rescale the image resolution to use (when using EfficientNet-style
# compound-scaling coefficients) by args.resolution_coefficient.
# The unscaled sizes are stashed on `args` the first time this cell runs, so
# re-running the cell does not multiply the coefficient in a second time.
if not hasattr(args, "base_crop_size"):
    args.base_crop_size = args.crop_size
    args.base_img_size = args.img_size
args.crop_size = math.ceil(args.base_crop_size * args.resolution_coefficient)
args.img_size = math.ceil(args.base_img_size * args.resolution_coefficient)

In [None]:
def build_augmentation(final_crop):
    """Build the pad -> resize -> crop pipeline shared by training and validation.

    Pads so each side is at least ``args.crop_size``, resizes to
    ``args.img_size`` x ``args.img_size``, then applies ``final_crop``.

    Args:
        final_crop: the albumentations crop transform appended as the last step.

    Returns:
        An ``albumentations.Compose`` of the three steps.
    """
    # NOTE(review): Resize overwrites whatever size the padding produced, so the
    # padding cannot guarantee the final crop fits; the crop will likely fail if
    # args.img_size < args.crop_size. Consider padding after the resize instead.
    return albumentations.Compose([
        albumentations.PadIfNeeded(p=1, min_height=args.crop_size, min_width=args.crop_size),
        albumentations.Resize(args.img_size, args.img_size),
        final_crop,
    ])


# Training uses a random crop; validation uses a deterministic centre crop.
train_aug = build_augmentation(
    albumentations.RandomCrop(p=1, height=args.crop_size, width=args.crop_size)
)
val_aug = build_augmentation(
    albumentations.CenterCrop(p=1, height=args.crop_size, width=args.crop_size)
)

In [None]:
# Data augmentation is not implemented yet. Emit a warning (rather than a plain
# print) so that a requested-but-ignored option stands out in the output.
if args.data_augmentation:
    warnings.warn("Data Augmentation to be implemented...")

In [None]:
# Training split: batches are reshuffled every epoch; pinned host memory
# lets batches be copied to the GPU more efficiently.
train_dataset = ISIC2019_Dataset(albumentation=train_aug, data_partition="train")
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True,
)

In [None]:
# Validation split: kept in a fixed order (no shuffling).
val_dataset = ISIC2019_Dataset(albumentation=val_aug, data_partition="validation")
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    pin_memory=True,
)
print("Data loaded!\n")

### Model

In [None]:
# The number of output classes is the number of distinct labels in the
# training ground-truth table.
num_classes = len(np.unique(ISIC_TRAIN_DF_TRUTH.target))
print("{} Classes detected!".format(num_classes))
model = model_selector(args.model_name, num_classes, args.depth_coefficient, args.width_coefficient)
# nn.DataParallel requires the wrapped module's parameters and buffers to already
# be on device_ids[0] (cuda:0 here) before the forward pass. Nothing visible here
# guarantees that model_selector returns a CUDA model, so move it explicitly.
# On a CPU-only machine the model stays where it is and DataParallel just
# forwards calls to the wrapped module.
if torch.cuda.is_available():
    model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

### Training setup and learning-rate finder

In [None]:
# Force RMSprop regardless of the optimizer already stored in args; this section
# searches the learning rate for that optimizer.
OPTIMIZER_NAME = "rmsprop"
args.optimizer = OPTIMIZER_NAME

In [None]:
# Echo every configuration attribute as "name: value" so the run settings
# are recorded in the notebook output.
for arg_name, arg_value in vars(args).items():
    print("{}: {}".format(arg_name, arg_value))

In [None]:
# Learning-curve accumulators (presumably appended to by a training loop that is
# not part of this section), each a separate empty list.
progress_train_loss = []
progress_val_loss = []
progress_train_acc = []
progress_val_acc = []
# Best-so-far sentinels: a very large loss and a negative accuracy, so any real
# value improves on them.
best_loss = 10e10
best_acc = -1

In [None]:
# Multiply the learning rate by LR_GAMMA each time the epoch count reaches one
# of LR_MILESTONES.
LR_MILESTONES = [75, 135, 170]
LR_GAMMA = 0.2

criterion = nn.CrossEntropyLoss()
optimizer = get_optimizer(args.optimizer, model, lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=LR_MILESTONES,
    gamma=LR_GAMMA,
)

In [None]:
# Run the learning-rate finder on the training loader using the model, optimizer
# and criterion configured above; presumably `lrs` are the learning rates tried
# and `losses` the loss recorded at each (they are plotted against each other below).
# NOTE(review): `torchy` is not imported explicitly in this notebook; it is assumed
# to come from one of the `utils` star imports — confirm.
# NOTE(review): an LR search normally updates the model weights and optimizer state;
# confirm whether findLR restores them before any real training reuses `model`.
lrs, losses = torchy.utils.findLR(model, optimizer, criterion, train_loader, verbose=False)

In [None]:
# Number of loss values recorded by the LR finder; presumably inspected to pick the
# plotting window (init, fin) used in the next cell.
len(losses)

In [None]:
# Plot loss versus learning rate over a hard-coded window of recorded points.
# Python slicing clamps `fin` if fewer points were recorded.
init, fin = 200, 543
fig, ax = plt.subplots()
ax.plot(lrs[init:fin], losses[init:fin])
ax.set(
    title="LR finder ({})".format(args.optimizer),
    xlabel="Learning Rates",
    ylabel="Losses",
)
# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs("data", exist_ok=True)
fig.savefig("data/lr_find_" + args.optimizer + "_custom.png")
plt.show()