In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import transforms
import sys

In [None]:
# Notebooks have no __file__, so paths are resolved from the working directory:
# abspath() of a bare filename is just <cwd>/<filename>, hence base_dir is the cwd.
base_dir = os.path.dirname(os.path.abspath('train_lenet.ipynb'))
# Parent of the notebook folder; presumably the project root holding the
# `model` and `tools` packages imported in the next cell.
tool_dir = os.path.dirname(base_dir)
# Guard keeps the cell idempotent: re-running it must not append duplicates.
if tool_dir not in sys.path:
    sys.path.append(tool_dir)

In [None]:
from model.lenet import LeNet
from tools.common_tools import set_seed
from tools.my_dataset import RMBdataset

In [None]:
# ---- Training hyper-parameters ----
max_epoch = 10      # passes over the training set
batch_size = 4      # mini-batch size used by the train and valid loaders
lr = 0.01           # initial SGD learning rate
log_interval = 40   # print running train stats every N iterations
val_interval = 1    # run validation every N epochs

# Label string -> class index. Keys look like RMB denominations ('1' / '100');
# presumably these match the dataset's folder names (verify in tools.my_dataset).
rmb_label = {'1': 0, '100': 1}

# Project helper from tools.common_tools; presumably seeds the RNGs used for
# weight init, shuffling and augmentation (confirm there).
set_seed()

In [None]:
# Dataset layout: <base_dir>/../data/RMB_split_new/{train,valid}.
split_dir = os.path.abspath(os.path.join(base_dir, '..', 'data', 'RMB_split_new'))
if not os.path.exists(split_dir):
    # FileNotFoundError subclasses Exception, so any existing handler still matches;
    # the path in the message tells the reader where the data was expected.
    raise FileNotFoundError('no data: {}'.format(split_dir))
train_dir = os.path.join(split_dir, 'train')
valid_dir = os.path.join(split_dir, 'valid')

# Standard ImageNet per-channel (R, G, B) statistics. The R-channel std is 0.229;
# it was previously mistyped as 0.299, which mis-scaled the red channel.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training: resize to 32x32, then take a random 32x32 crop from the image padded
# by 4 pixels on each side (light translation augmentation), then normalize.
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation: deterministic, with the same resize/normalization and no augmentation.
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

train_data = RMBdataset(train_dir, train_transform)
valid_data = RMBdataset(valid_dir, valid_transform)

# Only the training set is shuffled; validation order does not affect the metrics.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size)

In [None]:
# Size of the validation dataset's internal `data_info` index (presumably one
# entry per image; confirm in tools.my_dataset). Shown as the cell output.
n_valid_samples = len(valid_data.data_info)
n_valid_samples

In [None]:
# One output per entry in the label map (2 for {'1', '100'}); deriving it from
# rmb_label keeps the head size in sync if classes are added.
net = LeNet(classes=len(rmb_label))
# Apply the weight-initialization scheme defined by model.lenet.LeNet.
net.initialize_weights()
# Display the module summary as the cell output.
net

In [None]:
# Softmax cross-entropy on raw logits vs. integer class targets, averaged over
# the batch ('mean' is the library default, spelled out for the reader).
criterion = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# SGD with momentum over every parameter of the network.
optimizer = optim.SGD(params=net.parameters(), lr=lr, momentum=0.9)
# Multiply the learning rate by 0.1 once every 10 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=10)

In [None]:
# Loss histories: per-iteration training loss and per-validation-pass mean loss.
train_curve, valid_curve = [], []

In [None]:
# Train for `max_epoch` epochs. Every `log_interval` batches, print the mean
# training loss over those batches and the running epoch accuracy. Every
# `val_interval` epochs, evaluate once on the whole validation set.
# Each batch's training loss is appended to `train_curve`; each validation
# pass appends its mean batch loss to `valid_curve`.
for epoch in range(max_epoch):

    loss_mean = 0.
    correct = 0.
    total = 0.

    # Must be reset every epoch: the validation pass below switches to eval mode.
    net.train()

    for i, data in enumerate(train_loader):

        inputs, labels = data
        outputs = net(inputs)

        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Running accuracy over the batches seen so far this epoch.
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Accumulate (previously `=+`, which overwrote) so the logged value is
        # the mean loss of the last `log_interval` batches.
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print('train : epoch{} itear{} loss{} acc{}'.format(epoch, i + 1, loss_mean, correct / total))
            loss_mean = 0

    # StepLR's step_size counts scheduler.step() calls, so step once per epoch
    # (after the epoch's optimizer steps), not once per batch.
    scheduler.step()

    # Validate once per `val_interval` epochs, outside the batch loop.
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                # Use the validation batch's own targets for loss and accuracy.
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                loss_val += loss.item()

            # Mean of per-batch losses over the validation set.
            loss_val_epoch = loss_val / len(valid_loader)
            valid_curve.append(loss_val_epoch)

            print('valid: epoch : {}  itear {}  loss {}  acc {}'.format(epoch, j + 1, loss_val_epoch, correct_val / total_val))
        


