In [2]:
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import import_ipynb
from torchvision import datasets

from model.resnet import ResNet18

import json
import os

from irg import IRG

importing Jupyter notebook from irg.ipynb


In [3]:
import Ranger
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import timm
from torchinfo import summary #打印网络模型各层信息的库

In [4]:
def get_dataloader(train_dir, val_dir, batch_size, num_workers=0, pin_memory=False):
    """Build the training and validation DataLoaders from ImageFolder directories.

    Side effect: writes the training ``class_to_idx`` mapping to ``class.txt``
    (Python repr) and ``class.json`` (JSON) in the current working directory.

    Args:
        train_dir: root of the training set, one sub-directory per class.
        val_dir: root of the validation set; its class folders must map to the
            same indices as in ``train_dir``.
        batch_size: mini-batch size used by both loaders.
        num_workers: DataLoader worker processes. The default 0 loads in the main
            process, which matches the previous behaviour.
        pin_memory: forwarded to DataLoader; usually worth enabling for CUDA training.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and applies
        augmentation. The validation loader only resizes and normalizes, and
        does not shuffle.

    Raises:
        ValueError: if a validation class is missing from the training set or has
            a different index. ImageFolder indexes each folder independently, so
            a mismatch would silently mislabel validation samples.
    """
    # Per-channel mean/std shared by both pipelines. Presumably these are
    # precomputed statistics of this dataset (TODO confirm).
    normalize = transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803],
                                     std=[0.18507297, 0.18050247, 0.16784933])

    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    dataset_train = datasets.ImageFolder(train_dir, transform=transform)
    dataset_test = datasets.ImageFolder(val_dir, transform=transform_test)

    # Labels are only comparable if each validation class has the same index it
    # has in the training set.
    train_map = dataset_train.class_to_idx
    mismatched = sorted(cls for cls, idx in dataset_test.class_to_idx.items()
                        if train_map.get(cls) != idx)
    if mismatched:
        raise ValueError(
            'validation classes {} are missing from, or indexed differently than, '
            'the training set in {!r}'.format(mismatched, train_dir))

    with open('class.txt', 'w', encoding='utf-8') as file:
        file.write(str(train_map))
    with open('class.json', 'w', encoding='utf-8') as file:
        json.dump(train_map, file)

    train_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory)
    test_loader = torch.utils.data.DataLoader(
        dataset_test, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, test_loader

In [25]:
def train(model, teacher_model, criterionKD, criterionCls, train_dataloader, test_dataloader,
          device, batch_size, num_epoch, learning_rate, optim, init=False, scheduler_type='Cosine',
          lambda_kd=1.0, file_dir='kd_model', t_max=20):
    """Train the student ``model`` with classification plus distillation loss.

    Both networks must return a 6-tuple ``(l1, l2, l3, l4, fea, logits)``. The
    layer-3, layer-4, feature and logit outputs are passed to ``criterionKD`` as
    ``criterionKD([student...], [teacher...])``. The teacher outputs are
    computed without gradients.

    Each epoch runs one optimisation pass over ``train_dataloader``, then one
    scheduler step (if a scheduler is enabled), then an evaluation pass over
    ``test_dataloader``. Whenever validation accuracy improves, the student's
    ``state_dict`` is saved as ``<file_dir>/<accuracy>best.pth``.

    Args:
        model: student network to train (moved to ``device``).
        teacher_model: frozen teacher (moved to ``device`` and put in eval mode).
        criterionKD: distillation loss over lists of student/teacher outputs.
        criterionCls: classification loss on (logits, target).
        train_dataloader, test_dataloader: iterables of ``(data, target)`` batches.
        device: torch device for both models and all batches.
        batch_size: unused; kept for backward compatibility.
        num_epoch: number of epochs.
        learning_rate: initial learning rate.
        optim: optimizer name, one of 'sgd', 'adam' or 'adamW'. Any other value
            falls back to Adam.
        init: if True, Xavier-normal-initialise every ``nn.Linear`` weight of the
            student first.
        scheduler_type: 'Cosine' enables ``CosineAnnealingLR``; any other value
            keeps the learning rate constant.
        lambda_kd: weight of the distillation loss. Previously this was read from
            a global variable.
        file_dir: checkpoint directory, created if missing. Previously this was a
            global variable.
        t_max: ``T_max`` of the cosine schedule. The schedule is periodic with
            period ``2 * t_max``. When ``num_epoch > t_max``, the learning rate
            climbs back towards ``learning_rate`` after ``t_max`` epochs. Pass
            ``t_max=num_epoch`` for a single monotone decay.

    Returns:
        dict of per-epoch lists: ``'train_loss'``, ``'train_acc'``, ``'eval_acc'``.
    """
    def init_xavier(m):
        # Exact type check (not isinstance), so Linear subclasses are untouched.
        if type(m) is nn.Linear:
            nn.init.xavier_normal_(m.weight)

    if init:
        model.apply(init_xavier)

    print(f'Model is on device: {next(model.parameters()).device}')
    print("teacher is on device:{}".format(next(teacher_model.parameters()).device))

    print('training on:', device)
    model.to(device)
    # The teacher only supplies targets: keep it on the same device, in eval mode.
    teacher_model.to(device)
    teacher_model.eval()

    print("student_model is on device:{}".format(next(model.parameters()).device))
    print("teacher_model is on device:{}".format(next(teacher_model.parameters()).device))

    # A single if/elif chain. With independent `if`s, the trailing `else`
    # replaced an SGD optimizer with Adam.
    if optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif optim == 'adamW':
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    else:  # 'adam' and any unrecognised value
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    scheduler = None
    if scheduler_type == 'Cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=t_max, eta_min=1e-9)

    os.makedirs(file_dir, exist_ok=True)

    train_losses = []
    train_acces = []
    eval_acces = []
    best_acc = 0.0

    for epoch in range(1, num_epoch + 1):
        model.train()
        total_correct = 0
        total_loss = 0.0
        for data, target in tqdm(train_dataloader, desc='trainning'):
            data, target = data.to(device), target.to(device)
            l1_out_s, l2_out_s, l3_out_s, l4_out_s, fea_s, out_s = model(data)
            cls_loss = criterionCls(out_s, target)
            # Teacher outputs are constants for the loss; building a graph would only waste memory.
            with torch.no_grad():
                _, _, l3_out_t, l4_out_t, fea_t, out_t = teacher_model(data)
            kd_loss = criterionKD([l3_out_s, l4_out_s, fea_s, out_s],
                                  [l3_out_t, l4_out_t, fea_t, out_t]) * lambda_kd
            loss = cls_loss + kd_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_correct += (out_s.argmax(dim=1) == target).sum().item()
        avg_loss = total_loss / max(len(train_dataloader), 1)
        accuracy = total_correct / max(len(train_dataloader.dataset), 1)
        print('Epoch: {}, train_loss: {:.4f}, train_accuracy: {}/{} ({:.0f}%)'.format(epoch,
            avg_loss, total_correct, len(train_dataloader.dataset), 100 * accuracy))
        train_acces.append(accuracy)
        train_losses.append(avg_loss)
        if scheduler is not None:
            scheduler.step()

        model.eval()
        total_correct = 0
        total_loss = 0.0
        with torch.no_grad():
            for data, target in test_dataloader:
                data, target = data.to(device), target.to(device)
                _, _, _, _, _, out = model(data)
                loss = criterionCls(out, target)
                total_correct += (out.argmax(dim=1) == target).sum().item()
                total_loss += loss.item()
        avg_loss = total_loss / max(len(test_dataloader), 1)
        accuracy = total_correct / max(len(test_dataloader.dataset), 1)
        if accuracy > best_acc:
            best_acc = accuracy
            torch.save(model.state_dict(), os.path.join(file_dir, str(accuracy) + 'best.pth'))
        eval_acces.append(accuracy)
        print('Epoch: {}, test_loss: {:.4f}, test_accuracy: {}/{} ({:.0f}%)\n'.format(epoch,
            avg_loss, total_correct, len(test_dataloader.dataset), 100 * accuracy))

    return {'train_loss': train_losses, 'train_acc': train_acces, 'eval_acc': eval_acces}

In [13]:
if __name__ == '__main__':
    # Checkpoint directory: train() writes its best-accuracy state_dicts here.
    file_dir = 'kd_model'
    os.makedirs(file_dir, exist_ok=True)

In [20]:
    # --- Hyper-parameters ---------------------------------------------------
    learning_rate = 1e-4
    optimizer = 'adam'    # optimizer name understood by train(): 'sgd' | 'adam' | 'adamW'
    batch_size = 64
    num_epoch = 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Weights passed to IRG(w_irg_vert, w_irg_edge, w_irg_tran). Presumably these
    # are the vertex / edge / transformation terms of the IRG loss (see irg.ipynb).
    w_irg_vert = 0.1
    w_irg_edge = 5.0
    w_irg_tran = 5.0
    lambda_kd = 1.0       # distillation-loss weight
    # Seed torch's RNG so shuffling, augmentation and weight init are repeatable
    # (cuDNN kernels may still be nondeterministic).
    seed = 42
    torch.manual_seed(seed)

In [15]:
    # ImageFolder roots (one sub-directory per class) for training and validation.
    train_dir, test_dir = './data/train', './data/val'
    train_dataloader, test_dataloader = get_dataloader(
        train_dir=train_dir, val_dir=test_dir, batch_size=batch_size)

In [16]:
    # Student: ResNet18 with its classifier resized to the number of classes
    # ImageFolder found in the training directory (instead of a hard-coded 12).
    num_classes = len(train_dataloader.dataset.classes)
    model = ResNet18()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles a whole nn.Module, so only load trusted files.
    # On PyTorch >= 2.6 the default weights_only=True refuses full modules;
    # pass weights_only=False there. map_location lets a GPU-saved checkpoint
    # load on a CPU-only machine.
    teacher_model = torch.load('./teacherNet/best.pth', map_location=device)
    teacher_model.eval()
    # Freeze the teacher: it only provides distillation targets.
    for param in teacher_model.parameters():
        param.requires_grad_(False)
    criterionKD = IRG(w_irg_vert, w_irg_edge, w_irg_tran)
    criterionCls = nn.CrossEntropyLoss()

In [17]:
# Sanity check: display the device selected in the config cell.
device

device(type='cuda')

In [26]:
    # `optimizer` is the name string from the config cell. Do not pass `optim`:
    # that is the torch.optim module imported at the top, which matches no
    # optimizer name and silently falls back to Adam.
    history = train(model, teacher_model, criterionKD, criterionCls, train_dataloader, test_dataloader,
                    device, batch_size, num_epoch, learning_rate, optimizer)

Model is on device: cuda:0
teacher is on device:cuda:0
training on: cuda
student_model is on device:cuda:0
teacher_model is on device:cuda:0


trainning: 100%|██████████| 62/62 [01:44<00:00,  1.68s/it]


Epoch: 1, train_loss: 1.7463, train_accuracy: 3062/3961 (77%)
Epoch: 1, test_loss: 1.7152, test_accuracy: 881/1698 (52%)



trainning: 100%|██████████| 62/62 [01:43<00:00,  1.68s/it]


Epoch: 2, train_loss: 1.4554, train_accuracy: 3233/3961 (82%)
Epoch: 2, test_loss: 0.5655, test_accuracy: 1380/1698 (81%)



trainning: 100%|██████████| 62/62 [01:44<00:00,  1.68s/it]


Epoch: 3, train_loss: 1.2160, train_accuracy: 3370/3961 (85%)
Epoch: 3, test_loss: 0.6009, test_accuracy: 1352/1698 (80%)



trainning: 100%|██████████| 62/62 [01:43<00:00,  1.67s/it]


Epoch: 4, train_loss: 1.1676, train_accuracy: 3403/3961 (86%)
Epoch: 4, test_loss: 0.7546, test_accuracy: 1294/1698 (76%)



trainning: 100%|██████████| 62/62 [01:43<00:00,  1.68s/it]


Epoch: 5, train_loss: 1.0043, train_accuracy: 3489/3961 (88%)
Epoch: 5, test_loss: 0.4974, test_accuracy: 1427/1698 (84%)



trainning: 100%|██████████| 62/62 [01:44<00:00,  1.69s/it]


Epoch: 6, train_loss: 0.9380, train_accuracy: 3525/3961 (89%)
Epoch: 6, test_loss: 0.7010, test_accuracy: 1412/1698 (83%)



trainning: 100%|██████████| 62/62 [01:44<00:00,  1.69s/it]


Epoch: 7, train_loss: 0.7919, train_accuracy: 3569/3961 (90%)
Epoch: 7, test_loss: 0.5603, test_accuracy: 1398/1698 (82%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.70s/it]


Epoch: 8, train_loss: 0.7166, train_accuracy: 3638/3961 (92%)
Epoch: 8, test_loss: 0.3402, test_accuracy: 1515/1698 (89%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.70s/it]


Epoch: 9, train_loss: 0.7022, train_accuracy: 3670/3961 (93%)
Epoch: 9, test_loss: 0.3825, test_accuracy: 1490/1698 (88%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.70s/it]


Epoch: 10, train_loss: 0.6185, train_accuracy: 3701/3961 (93%)
Epoch: 10, test_loss: 0.4795, test_accuracy: 1438/1698 (85%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.71s/it]


Epoch: 11, train_loss: 0.5558, train_accuracy: 3726/3961 (94%)
Epoch: 11, test_loss: 0.5359, test_accuracy: 1472/1698 (87%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.70s/it]


Epoch: 12, train_loss: 0.4833, train_accuracy: 3776/3961 (95%)
Epoch: 12, test_loss: 0.2876, test_accuracy: 1552/1698 (91%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.70s/it]


Epoch: 13, train_loss: 0.4624, train_accuracy: 3792/3961 (96%)
Epoch: 13, test_loss: 0.3158, test_accuracy: 1538/1698 (91%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.70s/it]


Epoch: 14, train_loss: 0.4003, train_accuracy: 3831/3961 (97%)
Epoch: 14, test_loss: 0.3176, test_accuracy: 1550/1698 (91%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 15, train_loss: 0.3736, train_accuracy: 3841/3961 (97%)
Epoch: 15, test_loss: 0.3173, test_accuracy: 1544/1698 (91%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.71s/it]


Epoch: 16, train_loss: 0.3633, train_accuracy: 3847/3961 (97%)
Epoch: 16, test_loss: 0.2681, test_accuracy: 1571/1698 (93%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.70s/it]


Epoch: 17, train_loss: 0.3443, train_accuracy: 3857/3961 (97%)
Epoch: 17, test_loss: 0.3026, test_accuracy: 1547/1698 (91%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.71s/it]


Epoch: 18, train_loss: 0.3244, train_accuracy: 3866/3961 (98%)
Epoch: 18, test_loss: 0.2631, test_accuracy: 1573/1698 (93%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.71s/it]


Epoch: 19, train_loss: 0.3136, train_accuracy: 3860/3961 (97%)
Epoch: 19, test_loss: 0.2716, test_accuracy: 1575/1698 (93%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.71s/it]


Epoch: 20, train_loss: 0.3040, train_accuracy: 3878/3961 (98%)
Epoch: 20, test_loss: 0.2637, test_accuracy: 1579/1698 (93%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.73s/it]


Epoch: 21, train_loss: 0.3031, train_accuracy: 3879/3961 (98%)
Epoch: 21, test_loss: 0.2632, test_accuracy: 1583/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 22, train_loss: 0.3081, train_accuracy: 3883/3961 (98%)
Epoch: 22, test_loss: 0.2639, test_accuracy: 1579/1698 (93%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.71s/it]


Epoch: 23, train_loss: 0.3100, train_accuracy: 3874/3961 (98%)
Epoch: 23, test_loss: 0.2676, test_accuracy: 1577/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 24, train_loss: 0.2993, train_accuracy: 3876/3961 (98%)
Epoch: 24, test_loss: 0.2736, test_accuracy: 1571/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 25, train_loss: 0.3068, train_accuracy: 3884/3961 (98%)
Epoch: 25, test_loss: 0.2951, test_accuracy: 1553/1698 (91%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 26, train_loss: 0.3181, train_accuracy: 3873/3961 (98%)
Epoch: 26, test_loss: 0.3203, test_accuracy: 1567/1698 (92%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 27, train_loss: 0.3652, train_accuracy: 3841/3961 (97%)
Epoch: 27, test_loss: 0.3260, test_accuracy: 1542/1698 (91%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.76s/it]


Epoch: 28, train_loss: 0.3491, train_accuracy: 3840/3961 (97%)
Epoch: 28, test_loss: 0.4284, test_accuracy: 1475/1698 (87%)



trainning: 100%|██████████| 62/62 [01:54<00:00,  1.84s/it]


Epoch: 29, train_loss: 0.3450, train_accuracy: 3845/3961 (97%)
Epoch: 29, test_loss: 0.3469, test_accuracy: 1537/1698 (91%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 30, train_loss: 0.4007, train_accuracy: 3835/3961 (97%)
Epoch: 30, test_loss: 0.3990, test_accuracy: 1506/1698 (89%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 31, train_loss: 0.3887, train_accuracy: 3829/3961 (97%)
Epoch: 31, test_loss: 0.3369, test_accuracy: 1540/1698 (91%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 32, train_loss: 0.4406, train_accuracy: 3790/3961 (96%)
Epoch: 32, test_loss: 0.4779, test_accuracy: 1436/1698 (85%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.75s/it]


Epoch: 33, train_loss: 0.4299, train_accuracy: 3794/3961 (96%)
Epoch: 33, test_loss: 0.4025, test_accuracy: 1503/1698 (89%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.75s/it]


Epoch: 34, train_loss: 0.4407, train_accuracy: 3798/3961 (96%)
Epoch: 34, test_loss: 0.3988, test_accuracy: 1526/1698 (90%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 35, train_loss: 0.4674, train_accuracy: 3774/3961 (95%)
Epoch: 35, test_loss: 0.5797, test_accuracy: 1428/1698 (84%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.75s/it]


Epoch: 36, train_loss: 0.5039, train_accuracy: 3755/3961 (95%)
Epoch: 36, test_loss: 0.5116, test_accuracy: 1438/1698 (85%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.74s/it]


Epoch: 37, train_loss: 0.4155, train_accuracy: 3834/3961 (97%)
Epoch: 37, test_loss: 1.6445, test_accuracy: 1124/1698 (66%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.75s/it]


Epoch: 38, train_loss: 0.5236, train_accuracy: 3775/3961 (95%)
Epoch: 38, test_loss: 0.4471, test_accuracy: 1498/1698 (88%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 39, train_loss: 0.3969, train_accuracy: 3821/3961 (96%)
Epoch: 39, test_loss: 1.5731, test_accuracy: 1120/1698 (66%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 40, train_loss: 0.4184, train_accuracy: 3804/3961 (96%)
Epoch: 40, test_loss: 0.7256, test_accuracy: 1362/1698 (80%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 41, train_loss: 0.4320, train_accuracy: 3818/3961 (96%)
Epoch: 41, test_loss: 0.8370, test_accuracy: 1404/1698 (83%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 42, train_loss: 0.4048, train_accuracy: 3832/3961 (97%)
Epoch: 42, test_loss: 0.3589, test_accuracy: 1523/1698 (90%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.75s/it]


Epoch: 43, train_loss: 0.3280, train_accuracy: 3874/3961 (98%)
Epoch: 43, test_loss: 0.4823, test_accuracy: 1471/1698 (87%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 44, train_loss: 0.3174, train_accuracy: 3868/3961 (98%)
Epoch: 44, test_loss: 1.7035, test_accuracy: 1350/1698 (80%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 45, train_loss: 0.2853, train_accuracy: 3889/3961 (98%)
Epoch: 45, test_loss: 0.3743, test_accuracy: 1547/1698 (91%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 46, train_loss: 0.2458, train_accuracy: 3912/3961 (99%)
Epoch: 46, test_loss: 0.3068, test_accuracy: 1572/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 47, train_loss: 0.2303, train_accuracy: 3917/3961 (99%)
Epoch: 47, test_loss: 0.3564, test_accuracy: 1549/1698 (91%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 48, train_loss: 0.2017, train_accuracy: 3919/3961 (99%)
Epoch: 48, test_loss: 0.3399, test_accuracy: 1550/1698 (91%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.74s/it]


Epoch: 49, train_loss: 0.1925, train_accuracy: 3939/3961 (99%)
Epoch: 49, test_loss: 0.3907, test_accuracy: 1527/1698 (90%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 50, train_loss: 0.1851, train_accuracy: 3940/3961 (99%)
Epoch: 50, test_loss: 1.5306, test_accuracy: 1418/1698 (84%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.74s/it]


Epoch: 51, train_loss: 0.1788, train_accuracy: 3943/3961 (100%)
Epoch: 51, test_loss: 0.3257, test_accuracy: 1571/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 52, train_loss: 0.1808, train_accuracy: 3946/3961 (100%)
Epoch: 52, test_loss: 0.3203, test_accuracy: 1570/1698 (92%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 53, train_loss: 0.1659, train_accuracy: 3944/3961 (100%)
Epoch: 53, test_loss: 0.2948, test_accuracy: 1588/1698 (94%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.71s/it]


Epoch: 54, train_loss: 0.1577, train_accuracy: 3940/3961 (99%)
Epoch: 54, test_loss: 0.2850, test_accuracy: 1580/1698 (93%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.71s/it]


Epoch: 55, train_loss: 0.1513, train_accuracy: 3947/3961 (100%)
Epoch: 55, test_loss: 0.2843, test_accuracy: 1586/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 56, train_loss: 0.1415, train_accuracy: 3950/3961 (100%)
Epoch: 56, test_loss: 0.2834, test_accuracy: 1586/1698 (93%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 57, train_loss: 0.1324, train_accuracy: 3952/3961 (100%)
Epoch: 57, test_loss: 0.2828, test_accuracy: 1586/1698 (93%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.71s/it]


Epoch: 58, train_loss: 0.1286, train_accuracy: 3948/3961 (100%)
Epoch: 58, test_loss: 0.2908, test_accuracy: 1594/1698 (94%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.71s/it]


Epoch: 59, train_loss: 0.1341, train_accuracy: 3955/3961 (100%)
Epoch: 59, test_loss: 0.2863, test_accuracy: 1588/1698 (94%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.71s/it]


Epoch: 60, train_loss: 0.1312, train_accuracy: 3954/3961 (100%)
Epoch: 60, test_loss: 0.2830, test_accuracy: 1590/1698 (94%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 61, train_loss: 0.1316, train_accuracy: 3952/3961 (100%)
Epoch: 61, test_loss: 0.2876, test_accuracy: 1586/1698 (93%)



trainning: 100%|██████████| 62/62 [01:49<00:00,  1.76s/it]


Epoch: 62, train_loss: 0.1309, train_accuracy: 3954/3961 (100%)
Epoch: 62, test_loss: 0.2870, test_accuracy: 1590/1698 (94%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 63, train_loss: 0.1287, train_accuracy: 3958/3961 (100%)
Epoch: 63, test_loss: 0.2870, test_accuracy: 1584/1698 (93%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 64, train_loss: 0.1331, train_accuracy: 3951/3961 (100%)
Epoch: 64, test_loss: 0.2859, test_accuracy: 1588/1698 (94%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 65, train_loss: 0.1301, train_accuracy: 3953/3961 (100%)
Epoch: 65, test_loss: 0.2914, test_accuracy: 1591/1698 (94%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 66, train_loss: 0.1269, train_accuracy: 3957/3961 (100%)
Epoch: 66, test_loss: 0.2962, test_accuracy: 1591/1698 (94%)



trainning: 100%|██████████| 62/62 [01:44<00:00,  1.68s/it]


Epoch: 67, train_loss: 0.1370, train_accuracy: 3948/3961 (100%)
Epoch: 67, test_loss: 0.2956, test_accuracy: 1579/1698 (93%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 68, train_loss: 0.1355, train_accuracy: 3951/3961 (100%)
Epoch: 68, test_loss: 0.3191, test_accuracy: 1580/1698 (93%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 69, train_loss: 0.1573, train_accuracy: 3949/3961 (100%)
Epoch: 69, test_loss: 0.3144, test_accuracy: 1585/1698 (93%)



trainning: 100%|██████████| 62/62 [01:45<00:00,  1.71s/it]


Epoch: 70, train_loss: 0.1610, train_accuracy: 3944/3961 (100%)
Epoch: 70, test_loss: 0.2995, test_accuracy: 1577/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 71, train_loss: 0.1606, train_accuracy: 3944/3961 (100%)
Epoch: 71, test_loss: 0.3491, test_accuracy: 1571/1698 (93%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.74s/it]


Epoch: 72, train_loss: 0.1878, train_accuracy: 3935/3961 (99%)
Epoch: 72, test_loss: 0.3518, test_accuracy: 1550/1698 (91%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 73, train_loss: 0.2330, train_accuracy: 3919/3961 (99%)
Epoch: 73, test_loss: 0.4306, test_accuracy: 1498/1698 (88%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 74, train_loss: 0.2604, train_accuracy: 3909/3961 (99%)
Epoch: 74, test_loss: 0.4441, test_accuracy: 1513/1698 (89%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 75, train_loss: 0.2663, train_accuracy: 3895/3961 (98%)
Epoch: 75, test_loss: 0.7804, test_accuracy: 1434/1698 (84%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 76, train_loss: 0.3021, train_accuracy: 3886/3961 (98%)
Epoch: 76, test_loss: 0.3742, test_accuracy: 1536/1698 (90%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.73s/it]


Epoch: 77, train_loss: 0.2839, train_accuracy: 3887/3961 (98%)
Epoch: 77, test_loss: 1.2898, test_accuracy: 1329/1698 (78%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 78, train_loss: 0.3234, train_accuracy: 3869/3961 (98%)
Epoch: 78, test_loss: 0.4299, test_accuracy: 1530/1698 (90%)



trainning: 100%|██████████| 62/62 [01:46<00:00,  1.72s/it]


Epoch: 79, train_loss: 0.2590, train_accuracy: 3904/3961 (99%)
Epoch: 79, test_loss: 0.3415, test_accuracy: 1555/1698 (92%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 80, train_loss: 0.2435, train_accuracy: 3923/3961 (99%)
Epoch: 80, test_loss: 0.5956, test_accuracy: 1473/1698 (87%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 81, train_loss: 0.2114, train_accuracy: 3924/3961 (99%)
Epoch: 81, test_loss: 0.8374, test_accuracy: 1442/1698 (85%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 82, train_loss: 0.2355, train_accuracy: 3917/3961 (99%)
Epoch: 82, test_loss: 0.5513, test_accuracy: 1463/1698 (86%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 83, train_loss: 0.3250, train_accuracy: 3881/3961 (98%)
Epoch: 83, test_loss: 0.4443, test_accuracy: 1508/1698 (89%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 84, train_loss: 0.2425, train_accuracy: 3912/3961 (99%)
Epoch: 84, test_loss: 0.4011, test_accuracy: 1530/1698 (90%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 85, train_loss: 0.2124, train_accuracy: 3926/3961 (99%)
Epoch: 85, test_loss: 0.3786, test_accuracy: 1532/1698 (90%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 86, train_loss: 0.1998, train_accuracy: 3938/3961 (99%)
Epoch: 86, test_loss: 0.4995, test_accuracy: 1510/1698 (89%)



trainning: 100%|██████████| 62/62 [01:50<00:00,  1.79s/it]


Epoch: 87, train_loss: 0.1829, train_accuracy: 3937/3961 (99%)
Epoch: 87, test_loss: 0.3052, test_accuracy: 1586/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 88, train_loss: 0.1612, train_accuracy: 3950/3961 (100%)
Epoch: 88, test_loss: 0.3869, test_accuracy: 1564/1698 (92%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.74s/it]


Epoch: 89, train_loss: 0.1395, train_accuracy: 3955/3961 (100%)
Epoch: 89, test_loss: 0.3410, test_accuracy: 1563/1698 (92%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 90, train_loss: 0.1621, train_accuracy: 3959/3961 (100%)
Epoch: 90, test_loss: 0.2998, test_accuracy: 1584/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 91, train_loss: 0.1317, train_accuracy: 3957/3961 (100%)
Epoch: 91, test_loss: 0.2838, test_accuracy: 1580/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 92, train_loss: 0.1267, train_accuracy: 3958/3961 (100%)
Epoch: 92, test_loss: 0.3072, test_accuracy: 1587/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 93, train_loss: 0.1263, train_accuracy: 3956/3961 (100%)
Epoch: 93, test_loss: 0.2896, test_accuracy: 1579/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 94, train_loss: 0.1175, train_accuracy: 3955/3961 (100%)
Epoch: 94, test_loss: 0.3035, test_accuracy: 1583/1698 (93%)



trainning: 100%|██████████| 62/62 [01:50<00:00,  1.78s/it]


Epoch: 95, train_loss: 0.1131, train_accuracy: 3955/3961 (100%)
Epoch: 95, test_loss: 0.2982, test_accuracy: 1580/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 96, train_loss: 0.1134, train_accuracy: 3960/3961 (100%)
Epoch: 96, test_loss: 0.2891, test_accuracy: 1587/1698 (93%)



trainning: 100%|██████████| 62/62 [01:48<00:00,  1.75s/it]


Epoch: 97, train_loss: 0.1080, train_accuracy: 3957/3961 (100%)
Epoch: 97, test_loss: 0.2930, test_accuracy: 1586/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 98, train_loss: 0.1024, train_accuracy: 3959/3961 (100%)
Epoch: 98, test_loss: 0.2881, test_accuracy: 1584/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.74s/it]


Epoch: 99, train_loss: 0.1060, train_accuracy: 3961/3961 (100%)
Epoch: 99, test_loss: 0.2897, test_accuracy: 1583/1698 (93%)



trainning: 100%|██████████| 62/62 [01:47<00:00,  1.73s/it]


Epoch: 100, train_loss: 0.1047, train_accuracy: 3957/3961 (100%)
Epoch: 100, test_loss: 0.2897, test_accuracy: 1583/1698 (93%)



In [27]:
# Exploratory lookup of torchvision's Pad transform. Pad is not used by get_dataloader above.
# NOTE(review): leftover interactive cell; consider removing it from the final notebook.
help(transforms.Pad)

Help on class Pad in module torchvision.transforms.transforms:

class Pad(torch.nn.modules.module.Module)
 |  Pad(padding, fill=0, padding_mode='constant')
 |  
 |  Pad the given image on all sides with the given "pad" value.
 |  If the image is torch Tensor, it is expected
 |  to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
 |  at most 3 leading dimensions for mode edge,
 |  and an arbitrary number of leading dimensions for mode constant
 |  
 |  Args:
 |      padding (int or sequence): Padding on each border. If a single int is provided this
 |          is used to pad all borders. If sequence of length 2 is provided this is the padding
 |          on left/right and top/bottom respectively. If a sequence of length 4 is provided
 |          this is the padding for the left, top, right and bottom borders respectively.
 |  
 |          .. note::
 |              In torchscript mode padding as single int is not supported, use a sequen