In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms import ToTensor, Lambda, Normalize, CenterCrop
from torchvision.io.image import ImageReadMode
from torch import nn
import torchvision.models as models

from glob import glob
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import cv2


In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive; the following
# cells change into a folder below this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd drive/MyDrive/

In [None]:
ls

In [None]:
!pip install import-ipynb
import import_ipynb

In [None]:
# Project-local modules, resolved from the current working directory.
# NOTE(review): these are presumably .ipynb notebooks loaded through the
# import_ipynb hook imported above -- confirm; if so, importing them executes
# the notebooks' cells.
from Custom_Read_Data import CustomReadData
from Custom_image_dataset import CustomImageDataset
from Train import Model_Training
# Model definitions; only `resnet` is referenced in the cells visible here.
import densenet 
import vgg_net
import resnet
import Inception_V3

In [None]:
# Pick the compute device: the first CUDA GPU when PyTorch reports CUDA as
# available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'using {device} device')

In [None]:
# Data directory and labels CSV, both relative to the current working directory.
data_path = './trainval/'
labels_path = './labels.csv'

# CustomReadData comes from the project's Custom_Read_Data module (not shown
# here). im_read() returns two objects that are later sliced in parallel and
# passed to CustomImageDataset; presumably `files` are the image files and
# `gd` the matching labels -- TODO confirm against Custom_Read_Data.
read_data = CustomReadData(data_path, labels_path)
files, gd = read_data.im_read()

In [None]:
# Show how many entries im_read() returned; the train/validation split below
# takes the first 6000 for training and the remainder for validation.
len(files)

In [None]:
# Keyword arguments for the training DataLoader: mini-batches of 8 samples,
# reshuffled by the loader. `num_workers` is not set, so the DataLoader
# default applies.
params_training = dict(batch_size=8, shuffle=True)

In [None]:
# Keyword arguments for the validation DataLoader: one sample per batch.
# Shuffling is turned off: evaluation does not benefit from a random order,
# and a fixed order makes validation runs reproducible and comparable.
params_validation = {
    'batch_size': 1,
    'shuffle': False,
    #'num_workers': 2
}

In [None]:
# Number of leading samples used for training; everything after them is held
# out for validation.
n_train = 6000

# Fail early with a clear message instead of silently building an empty
# validation set when there are not enough samples for the split.
if len(files) <= n_train:
    raise ValueError(
        f"need more than {n_train} samples to leave a validation split, "
        f"got {len(files)}"
    )

# `files` and `gd` are sliced at the same index so the two stay aligned.
training_set = CustomImageDataset(files[:n_train], gd[:n_train])
validation_set = CustomImageDataset(files[n_train:], gd[n_train:])

# Batch generators built from the DataLoader settings defined in the cells above.
training_generator = DataLoader(training_set, **params_training)
validation_generator = DataLoader(validation_set, **params_validation)

In [None]:
# Build the network from the project's `resnet` module, configured through its
# keyword arguments for 3 image channels and 3 classes, then move it to the
# selected device. As the last expression of the cell, the value returned by
# `model.to(device)` is also displayed.
model = resnet.ResNet182(img_channel=3, num_classes=3)
model.to(device)

In [None]:
# Hyper-parameters

learning_rate = 1e-3  # learning rate handed to the Adam optimizer
epochs = 150          # number of iterations of the training loop per run of the training cell

In [None]:
# Offset added to the loop index when numbering epochs in the training loop:
# 0 for a fresh run; the optional checkpoint-loading cell overwrites it with
# the epoch stored in the checkpoint.
epoch = 0
# Adam optimizer over all parameters of `model`.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Checkpoint file rewritten after every epoch (only the latest epoch is kept).
FILE = './checkpoints_resnet_182.pth'
loss_fn = nn.CrossEntropyLoss()

# Model_Training comes from the project's Train module (not shown here); it is
# given both data loaders, the model, the loss, the optimizer and the device.
train_func = Model_Training(training_generator, validation_generator, model, loss_fn, optimizer, device)

# `epoch` is the number of epochs already completed (0 for a fresh run, or the
# value restored from a checkpoint). Numbering continues from there.
start_epoch = epoch
for t in range(epochs):
    current_epoch = start_epoch + t + 1
    print(f"Epoch {current_epoch}\n-------------------------------")
    # Return values are kept for inspection; presumably the epoch's training
    # and validation loss -- TODO confirm against the Train module.
    train_loss = train_func.train_loop()
    valid_loss = train_func.validation_loop()

    # Everything needed to resume: epoch number, model weights, optimizer state.
    checkpoint = {
        "epoch": current_epoch,
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict()
    }
    torch.save(checkpoint, FILE)
    # Keep the in-memory counter in step with the checkpoint just written, so
    # re-running this cell (after an interruption or to train longer) continues
    # the numbering instead of restarting it and saving a stale epoch.
    epoch = current_epoch
print("Done!")

In [None]:
# Save only the model weights (state_dict) to a separate file; unlike the
# per-epoch checkpoint, no optimizer state or epoch number is stored here.
PATH = './res_net182.pth'
torch.save(model.state_dict(), PATH)

In [None]:
# Optional: resume from the per-epoch checkpoint.
# Rebuilds the model and optimizer, then restores their state and the epoch
# counter from the checkpoint file.

# map_location remaps the stored tensors onto the currently selected device,
# so a checkpoint written on a GPU runtime also loads on a CPU-only runtime.
loaded_checkpoint = torch.load("./checkpoints_resnet_182.pth", map_location=device)
epoch = loaded_checkpoint["epoch"]
print(epoch)

# Same architecture and optimizer settings as used for training; the state
# dicts loaded below replace the freshly initialised values.
model = resnet.ResNet182(img_channel=3, num_classes=3)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

model.load_state_dict(loaded_checkpoint["model_state"])
optimizer.load_state_dict(loaded_checkpoint["optim_state"])
