In [1]:
import pickle
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision import transforms
import numpy as np
from einops import rearrange
import torch
from torchvision.models.video.resnet import VideoResNet, BasicBlock, R2Plus1dStem, Conv2Plus1D
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F

USE_CUDA = torch.cuda.is_available()
print(USE_CUDA)

device = torch.device('cuda:0' if USE_CUDA else 'cpu')
print('학습을 진행하는 기기:', device)

True
학습을 진행하는 기기: cuda:0


In [2]:
with open('training_set.dat', "rb") as training_file:
    x = pickle.load(training_file)

In [3]:
with open('test_set.dat', "rb") as training_file:
    y = pickle.load(training_file)

In [4]:
x[0].shape

(60, 30, 224, 224, 3)

In [5]:
x[1]

array([0, 2, 1, 7, 1, 7, 5, 3, 1, 2, 6, 5, 3, 1, 0, 7, 8, 1, 8, 4, 9, 0,
       5, 2, 6, 4, 6, 5, 0, 2, 5, 9, 4, 3, 7, 0, 9, 6, 6, 3, 3, 2, 7, 5,
       7, 6, 4, 8, 4, 3, 9, 8, 4, 9, 8, 9, 1, 8, 2, 0])

In [6]:
def one_hot_encode(labels, num_classes):
    one_hot_labels = np.zeros((len(labels), num_classes))
    for i in range(len(labels)):
        one_hot_labels[i, labels[i]] = 1
    return one_hot_labels

In [7]:
temp_x = one_hot_encode(x[1], 10)
temp_y = one_hot_encode(y[1], 10)

In [8]:
print('temp x = ', temp_x)
print('temp y = ', temp_y)

temp x =  [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]


In [9]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.hub import load_state_dict_from_url
import torchvision
from functools import partial
from collections import OrderedDict
import math

import os,inspect,sys

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
sys.path.insert(0,currentdir)

def convert_relu_to_swish(model):
        for child_name, child in model.named_children():
            if isinstance(child, nn.ReLU):
                setattr(model, child_name, nn.SiLU(True))
                # setattr(model, child_name, Swish())
            else:
                convert_relu_to_swish(child)

class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mul_(torch.sigmoid(x))

class r2plus1d_18(nn.Module):
    def __init__(self, pretrained=False, num_classes=10, dropout_p=0.5):
        super(r2plus1d_18, self).__init__()
        self.pretrained = pretrained
        self.num_classes = num_classes
        model = torchvision.models.video.r2plus1d_18(pretrained=self.pretrained)
        # delete the last fc layer
        modules = list(model.children())[:-1]
        # print(modules)
        self.r2plus1d_18 = nn.Sequential(*modules)
        convert_relu_to_swish(self.r2plus1d_18)
        self.fc1 = nn.Linear(model.fc.in_features, self.num_classes)
        self.dropout = nn.Dropout(dropout_p, inplace=True)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        out = self.r2plus1d_18(x)
        # print(out.shape)
        # Flatten the layer to fc
        out = out.flatten(1)
        out = self.dropout(out)
        out = self.fc1(out)
        out = self.softmax(out)
        return out

class flow_r2plus1d_18(nn.Module):
    def __init__(self, pretrained=False, num_classes=10, dropout_p=0.5):
        super(flow_r2plus1d_18, self).__init__()
        self.pretrained = pretrained
        self.num_classes = num_classes
        model = torchvision.models.video.r2plus1d_18(pretrained=self.pretrained)

        model.stem[0] = nn.Conv3d(2, 45, kernel_size=(1, 7, 7),
                            stride=(1, 2, 2), padding=(0, 3, 3),
                            bias=False)

        # delete the last fc layer
        modules = list(model.children())[:-1]
        # print(modules)
        self.r2plus1d_18 = nn.Sequential(*modules)
        convert_relu_to_swish(self.r2plus1d_18)
        self.fc1 = nn.Linear(model.fc.in_features, self.num_classes)
        self.dropout = nn.Dropout(dropout_p, inplace=True)
    def forward(self, x):
        # print(x.size())
        out = self.r2plus1d_18(x)
        # print(out.shape)
        # Flatten the layer to fc
        out = out.flatten(1)
        out = self.dropout(out)
        out = self.fc1(out)

        return out

In [10]:
model = r2plus1d_18(num_classes = 10).to(device)



In [11]:
# from torchinfo import summary

# summary(model, input_size = (4,3,15,224,224), col_names = ['input_size','output_size','num_params'], verbose=1)

In [12]:
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [13]:
transform = transforms.Compose([
    # transforms.ToPILImage(),
    # transforms.Resize((224,224)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                      std=[0.229, 0.224, 0.225]),
])


In [14]:
class SignLanGuageDataset(Dataset):
    def __init__(self,imagedata,tagdata,transform):
        self.imagedata=imagedata
        self.tagdata=tagdata
        self.transform=transform

    def __len__(self):
        return len(self.imagedata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_data=(self.imagedata[idx])
        image_data=torch.FloatTensor(image_data)
        label=self.tagdata[idx]
        label = torch.FloatTensor(label)
        return image_data,label

In [15]:
batchsz = 4
num_workerssz = 4
epochs = 120

In [16]:
train_dataset = SignLanGuageDataset(imagedata=x[0],tagdata=temp_x,transform=transform)
valid_dataset = SignLanGuageDataset(imagedata=y[0],tagdata=temp_y,transform=transform)

In [17]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batchsz, shuffle=False, num_workers=num_workerssz)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batchsz, shuffle=False, num_workers=num_workerssz)

In [18]:
with torch.cuda.device(0):
    
    for epoch in range(epochs):
        train_avg_loss = 0
        val_avg_loss = 0
        
        model.train()
        for data, target in tqdm(train_dataloader):
            data = rearrange(data, 'b d h w c -> b c d h w')
            data = data.to(device)
            
            target = target.to(device)
            optimizer.zero_grad()
            output = model(data)

            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            train_avg_loss += loss
        train_avg_loss = train_avg_loss/len(train_dataloader)
        print('Epoch = {}, train_loss = {}'.format(epoch+1, train_avg_loss))
        with torch.no_grad(): # valid
            model.eval()
            for data, target in tqdm(valid_dataloader):
                data = rearrange(data, 'b d h w c -> b c d h w')
                data = data.to(device)
                
                target = target.to(device)
                
                hypothesis = model(data)
                val_loss = criterion(hypothesis, target)
                val_avg_loss += val_loss
                
            val_avg_loss = val_avg_loss/len(valid_dataloader)
        print('Epoch = {}, val_loss = {}'.format(epoch+1, val_avg_loss))
        print('\n')

  return F.conv3d(
100%|██████████| 15/15 [00:07<00:00,  2.05it/s]


Epoch = 1, train_loss = 2.318044900894165


100%|██████████| 5/5 [00:00<00:00,  7.08it/s]


Epoch = 1, val_loss = 2.3611502647399902




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 2, train_loss = 2.3037967681884766


100%|██████████| 5/5 [00:00<00:00,  7.16it/s]


Epoch = 2, val_loss = 2.3611502647399902




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 3, train_loss = 2.3001644611358643


100%|██████████| 5/5 [00:00<00:00,  7.50it/s]


Epoch = 3, val_loss = 2.3495442867279053




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 4, train_loss = 2.2898905277252197


100%|██████████| 5/5 [00:00<00:00,  7.74it/s]


Epoch = 4, val_loss = 2.3611502647399902




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 5, train_loss = 2.301513910293579


100%|██████████| 5/5 [00:00<00:00,  7.73it/s]


Epoch = 5, val_loss = 2.361923933029175




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 6, train_loss = 2.286975383758545


100%|██████████| 5/5 [00:00<00:00,  7.58it/s]


Epoch = 6, val_loss = 2.361147403717041




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 7, train_loss = 2.302485704421997


100%|██████████| 5/5 [00:00<00:00,  7.53it/s]


Epoch = 7, val_loss = 2.3594071865081787




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 8, train_loss = 2.285885810852051


100%|██████████| 5/5 [00:00<00:00,  7.27it/s]


Epoch = 8, val_loss = 2.3036534786224365




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 9, train_loss = 2.2766780853271484


100%|██████████| 5/5 [00:00<00:00,  7.16it/s]


Epoch = 9, val_loss = 2.3594796657562256




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 10, train_loss = 2.28775691986084


100%|██████████| 5/5 [00:00<00:00,  7.82it/s]


Epoch = 10, val_loss = 2.3517279624938965




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 11, train_loss = 2.310633897781372


100%|██████████| 5/5 [00:00<00:00,  7.59it/s]


Epoch = 11, val_loss = 2.359607458114624




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 12, train_loss = 2.293069839477539


100%|██████████| 5/5 [00:00<00:00,  7.48it/s]


Epoch = 12, val_loss = 2.3080568313598633




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 13, train_loss = 2.283794403076172


100%|██████████| 5/5 [00:00<00:00,  7.18it/s]


Epoch = 13, val_loss = 2.303455114364624




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 14, train_loss = 2.274831533432007


100%|██████████| 5/5 [00:00<00:00,  7.45it/s]


Epoch = 14, val_loss = 2.3607687950134277




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 15, train_loss = 2.284358263015747


100%|██████████| 5/5 [00:00<00:00,  6.99it/s]


Epoch = 15, val_loss = 2.3611502647399902




100%|██████████| 15/15 [00:06<00:00,  2.44it/s]


Epoch = 16, train_loss = 2.262082815170288


100%|██████████| 5/5 [00:00<00:00,  7.46it/s]


Epoch = 16, val_loss = 2.360926389694214




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 17, train_loss = 2.2726144790649414


100%|██████████| 5/5 [00:00<00:00,  7.59it/s]


Epoch = 17, val_loss = 2.3611502647399902




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 18, train_loss = 2.264068126678467


100%|██████████| 5/5 [00:00<00:00,  7.81it/s]


Epoch = 18, val_loss = 2.2900853157043457




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 19, train_loss = 2.245004892349243


100%|██████████| 5/5 [00:00<00:00,  7.91it/s]


Epoch = 19, val_loss = 2.3611514568328857




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 20, train_loss = 2.2076199054718018


100%|██████████| 5/5 [00:00<00:00,  7.05it/s]


Epoch = 20, val_loss = 2.3395659923553467




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 21, train_loss = 2.1844615936279297


100%|██████████| 5/5 [00:00<00:00,  7.48it/s]


Epoch = 21, val_loss = 2.360327959060669




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 22, train_loss = 2.1980955600738525


100%|██████████| 5/5 [00:00<00:00,  6.97it/s]


Epoch = 22, val_loss = 2.3610198497772217




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 23, train_loss = 2.1816205978393555


100%|██████████| 5/5 [00:00<00:00,  7.15it/s]


Epoch = 23, val_loss = 2.361149311065674




100%|██████████| 15/15 [00:06<00:00,  2.44it/s]


Epoch = 24, train_loss = 2.162656545639038


100%|██████████| 5/5 [00:00<00:00,  7.13it/s]


Epoch = 24, val_loss = 2.2029712200164795




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 25, train_loss = 2.1116559505462646


100%|██████████| 5/5 [00:00<00:00,  7.28it/s]


Epoch = 25, val_loss = 2.3611502647399902




100%|██████████| 15/15 [00:06<00:00,  2.44it/s]


Epoch = 26, train_loss = 2.0978477001190186


100%|██████████| 5/5 [00:00<00:00,  7.34it/s]


Epoch = 26, val_loss = 2.275020122528076




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 27, train_loss = 2.0812978744506836


100%|██████████| 5/5 [00:00<00:00,  7.07it/s]


Epoch = 27, val_loss = 2.359679937362671




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 28, train_loss = 2.058049201965332


100%|██████████| 5/5 [00:00<00:00,  7.61it/s]


Epoch = 28, val_loss = 2.3361613750457764




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 29, train_loss = 2.0916619300842285


100%|██████████| 5/5 [00:00<00:00,  7.85it/s]


Epoch = 29, val_loss = 2.312113046646118




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 30, train_loss = 2.128340482711792


100%|██████████| 5/5 [00:00<00:00,  7.74it/s]


Epoch = 30, val_loss = 2.196544647216797




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 31, train_loss = 2.0814144611358643


100%|██████████| 5/5 [00:00<00:00,  7.54it/s]


Epoch = 31, val_loss = 2.3611502647399902




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 32, train_loss = 2.0687315464019775


100%|██████████| 5/5 [00:00<00:00,  7.83it/s]


Epoch = 32, val_loss = 2.317056179046631




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 33, train_loss = 2.0574750900268555


100%|██████████| 5/5 [00:00<00:00,  7.72it/s]


Epoch = 33, val_loss = 2.1808712482452393




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 34, train_loss = 2.0380754470825195


100%|██████████| 5/5 [00:00<00:00,  7.75it/s]


Epoch = 34, val_loss = 2.3608248233795166




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 35, train_loss = 1.9867807626724243


100%|██████████| 5/5 [00:00<00:00,  7.31it/s]


Epoch = 35, val_loss = 2.2120347023010254




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 36, train_loss = 1.9312305450439453


100%|██████████| 5/5 [00:00<00:00,  7.61it/s]


Epoch = 36, val_loss = 2.2074780464172363




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 37, train_loss = 1.9111558198928833


100%|██████████| 5/5 [00:00<00:00,  7.20it/s]


Epoch = 37, val_loss = 2.2774658203125




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 38, train_loss = 1.9355226755142212


100%|██████████| 5/5 [00:00<00:00,  7.24it/s]


Epoch = 38, val_loss = 2.1955509185791016




100%|██████████| 15/15 [00:06<00:00,  2.44it/s]


Epoch = 39, train_loss = 1.9210553169250488


100%|██████████| 5/5 [00:00<00:00,  7.15it/s]


Epoch = 39, val_loss = 2.1074211597442627




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 40, train_loss = 1.9250519275665283


100%|██████████| 5/5 [00:00<00:00,  7.27it/s]


Epoch = 40, val_loss = 2.200835943222046




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 41, train_loss = 1.8772529363632202


100%|██████████| 5/5 [00:00<00:00,  7.92it/s]


Epoch = 41, val_loss = 2.099900007247925




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 42, train_loss = 1.871627688407898


100%|██████████| 5/5 [00:00<00:00,  7.73it/s]


Epoch = 42, val_loss = 2.092672109603882




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 43, train_loss = 1.885323405265808


100%|██████████| 5/5 [00:00<00:00,  7.41it/s]


Epoch = 43, val_loss = 2.219834089279175




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 44, train_loss = 1.921026587486267


100%|██████████| 5/5 [00:00<00:00,  7.91it/s]


Epoch = 44, val_loss = 2.2849392890930176




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 45, train_loss = 1.9068198204040527


100%|██████████| 5/5 [00:00<00:00,  7.71it/s]


Epoch = 45, val_loss = 2.358245611190796




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 46, train_loss = 1.9382137060165405


100%|██████████| 5/5 [00:00<00:00,  7.00it/s]


Epoch = 46, val_loss = 2.098756790161133




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 47, train_loss = 1.879426121711731


100%|██████████| 5/5 [00:00<00:00,  7.58it/s]


Epoch = 47, val_loss = 2.119056463241577




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 48, train_loss = 1.8775219917297363


100%|██████████| 5/5 [00:00<00:00,  7.83it/s]


Epoch = 48, val_loss = 2.1547610759735107




100%|██████████| 15/15 [00:06<00:00,  2.48it/s]


Epoch = 49, train_loss = 1.8560305833816528


100%|██████████| 5/5 [00:00<00:00,  6.87it/s]


Epoch = 49, val_loss = 2.1327342987060547




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 50, train_loss = 1.8350216150283813


100%|██████████| 5/5 [00:00<00:00,  7.42it/s]


Epoch = 50, val_loss = 2.1554949283599854




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 51, train_loss = 1.8368555307388306


100%|██████████| 5/5 [00:00<00:00,  7.38it/s]


Epoch = 51, val_loss = 2.194803237915039




100%|██████████| 15/15 [00:06<00:00,  2.44it/s]


Epoch = 52, train_loss = 1.8254034519195557


100%|██████████| 5/5 [00:00<00:00,  7.58it/s]


Epoch = 52, val_loss = 2.1033992767333984




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 53, train_loss = 1.8493458032608032


100%|██████████| 5/5 [00:00<00:00,  7.45it/s]


Epoch = 53, val_loss = 2.1802680492401123




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 54, train_loss = 1.8140891790390015


100%|██████████| 5/5 [00:00<00:00,  7.14it/s]


Epoch = 54, val_loss = 2.055682897567749




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 55, train_loss = 1.8226922750473022


100%|██████████| 5/5 [00:00<00:00,  7.12it/s]


Epoch = 55, val_loss = 2.1611623764038086




100%|██████████| 15/15 [00:06<00:00,  2.43it/s]


Epoch = 56, train_loss = 1.8244922161102295


100%|██████████| 5/5 [00:00<00:00,  7.24it/s]


Epoch = 56, val_loss = 2.2826344966888428




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 57, train_loss = 1.845365285873413


100%|██████████| 5/5 [00:00<00:00,  7.47it/s]


Epoch = 57, val_loss = 2.1653473377227783




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 58, train_loss = 1.7991358041763306


100%|██████████| 5/5 [00:00<00:00,  7.42it/s]


Epoch = 58, val_loss = 2.1012275218963623




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 59, train_loss = 1.7907358407974243


100%|██████████| 5/5 [00:00<00:00,  6.92it/s]


Epoch = 59, val_loss = 2.173393964767456




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 60, train_loss = 1.779987096786499


100%|██████████| 5/5 [00:00<00:00,  7.27it/s]


Epoch = 60, val_loss = 2.0399138927459717




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 61, train_loss = 1.772046446800232


100%|██████████| 5/5 [00:00<00:00,  7.16it/s]


Epoch = 61, val_loss = 2.072239637374878




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 62, train_loss = 1.7324784994125366


100%|██████████| 5/5 [00:00<00:00,  7.33it/s]


Epoch = 62, val_loss = 2.115116596221924




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 63, train_loss = 1.7623580694198608


100%|██████████| 5/5 [00:00<00:00,  8.01it/s]


Epoch = 63, val_loss = 2.1008925437927246




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 64, train_loss = 1.7459684610366821


100%|██████████| 5/5 [00:00<00:00,  7.32it/s]


Epoch = 64, val_loss = 2.324049472808838




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 65, train_loss = 1.7326985597610474


100%|██████████| 5/5 [00:00<00:00,  7.82it/s]


Epoch = 65, val_loss = 2.0579757690429688




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 66, train_loss = 1.7025877237319946


100%|██████████| 5/5 [00:00<00:00,  7.46it/s]


Epoch = 66, val_loss = 2.018237590789795




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 67, train_loss = 1.6817022562026978


100%|██████████| 5/5 [00:00<00:00,  7.65it/s]


Epoch = 67, val_loss = 2.024317979812622




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 68, train_loss = 1.6873663663864136


100%|██████████| 5/5 [00:00<00:00,  7.63it/s]


Epoch = 68, val_loss = 2.0381107330322266




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 69, train_loss = 1.6890695095062256


100%|██████████| 5/5 [00:00<00:00,  7.57it/s]


Epoch = 69, val_loss = 2.331376791000366




100%|██████████| 15/15 [00:06<00:00,  2.48it/s]


Epoch = 70, train_loss = 1.6742559671401978


100%|██████████| 5/5 [00:00<00:00,  7.78it/s]


Epoch = 70, val_loss = 2.028528928756714




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 71, train_loss = 1.668691635131836


100%|██████████| 5/5 [00:00<00:00,  7.42it/s]


Epoch = 71, val_loss = 2.015164613723755




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 72, train_loss = 1.662764549255371


100%|██████████| 5/5 [00:00<00:00,  7.17it/s]


Epoch = 72, val_loss = 1.992177963256836




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 73, train_loss = 1.6614705324172974


100%|██████████| 5/5 [00:00<00:00,  7.41it/s]


Epoch = 73, val_loss = 2.1131367683410645




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 74, train_loss = 1.6610195636749268


100%|██████████| 5/5 [00:00<00:00,  7.24it/s]


Epoch = 74, val_loss = 2.174809217453003




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 75, train_loss = 1.6607873439788818


100%|██████████| 5/5 [00:00<00:00,  7.52it/s]


Epoch = 75, val_loss = 2.016545057296753




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 76, train_loss = 1.6608232259750366


100%|██████████| 5/5 [00:00<00:00,  7.58it/s]


Epoch = 76, val_loss = 1.9934622049331665




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 77, train_loss = 1.6610995531082153


100%|██████████| 5/5 [00:00<00:00,  7.60it/s]


Epoch = 77, val_loss = 2.0272200107574463




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 78, train_loss = 1.6583746671676636


100%|██████████| 5/5 [00:00<00:00,  7.43it/s]


Epoch = 78, val_loss = 2.048492193222046




100%|██████████| 15/15 [00:06<00:00,  2.44it/s]


Epoch = 79, train_loss = 1.658413290977478


100%|██████████| 5/5 [00:00<00:00,  7.35it/s]


Epoch = 79, val_loss = 2.075981378555298




100%|██████████| 15/15 [00:06<00:00,  2.44it/s]


Epoch = 80, train_loss = 1.6570818424224854


100%|██████████| 5/5 [00:00<00:00,  6.90it/s]


Epoch = 80, val_loss = 2.063758611679077




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 81, train_loss = 1.657843828201294


100%|██████████| 5/5 [00:00<00:00,  7.60it/s]


Epoch = 81, val_loss = 2.063566207885742




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 82, train_loss = 1.6574350595474243


100%|██████████| 5/5 [00:00<00:00,  7.79it/s]


Epoch = 82, val_loss = 2.1204774379730225




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 83, train_loss = 1.658055067062378


100%|██████████| 5/5 [00:00<00:00,  7.35it/s]


Epoch = 83, val_loss = 2.0273494720458984




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 84, train_loss = 1.656896948814392


100%|██████████| 5/5 [00:00<00:00,  7.36it/s]


Epoch = 84, val_loss = 2.024071216583252




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 85, train_loss = 1.6569082736968994


100%|██████████| 5/5 [00:00<00:00,  7.50it/s]


Epoch = 85, val_loss = 2.1213536262512207




100%|██████████| 15/15 [00:06<00:00,  2.43it/s]


Epoch = 86, train_loss = 1.6586968898773193


100%|██████████| 5/5 [00:00<00:00,  7.93it/s]


Epoch = 86, val_loss = 2.067608118057251




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 87, train_loss = 1.6569546461105347


100%|██████████| 5/5 [00:00<00:00,  7.34it/s]


Epoch = 87, val_loss = 2.033153772354126




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 88, train_loss = 1.6559118032455444


100%|██████████| 5/5 [00:00<00:00,  7.89it/s]


Epoch = 88, val_loss = 2.0321762561798096




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 89, train_loss = 1.657205581665039


100%|██████████| 5/5 [00:00<00:00,  7.76it/s]


Epoch = 89, val_loss = 2.0300471782684326




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 90, train_loss = 1.6553826332092285


100%|██████████| 5/5 [00:00<00:00,  7.23it/s]


Epoch = 90, val_loss = 1.9863882064819336




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 91, train_loss = 1.655705451965332


100%|██████████| 5/5 [00:00<00:00,  7.69it/s]


Epoch = 91, val_loss = 1.9902855157852173




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 92, train_loss = 1.6556566953659058


100%|██████████| 5/5 [00:00<00:00,  7.57it/s]


Epoch = 92, val_loss = 2.0387723445892334




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 93, train_loss = 1.6564122438430786


100%|██████████| 5/5 [00:00<00:00,  7.57it/s]


Epoch = 93, val_loss = 2.0381479263305664




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 94, train_loss = 1.6553518772125244


100%|██████████| 5/5 [00:00<00:00,  7.13it/s]


Epoch = 94, val_loss = 2.021913766860962




100%|██████████| 15/15 [00:06<00:00,  2.44it/s]


Epoch = 95, train_loss = 1.6555213928222656


100%|██████████| 5/5 [00:00<00:00,  7.37it/s]


Epoch = 95, val_loss = 2.0291659832000732




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 96, train_loss = 1.6560485363006592


100%|██████████| 5/5 [00:00<00:00,  7.59it/s]


Epoch = 96, val_loss = 2.0815954208374023




100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Epoch = 97, train_loss = 1.6560726165771484


100%|██████████| 5/5 [00:00<00:00,  7.28it/s]


Epoch = 97, val_loss = 2.0362229347229004




100%|██████████| 15/15 [00:06<00:00,  2.46it/s]


Epoch = 98, train_loss = 1.6556414365768433


100%|██████████| 5/5 [00:00<00:00,  7.46it/s]


Epoch = 98, val_loss = 1.9836835861206055




100%|██████████| 15/15 [00:06<00:00,  2.45it/s]


Epoch = 99, train_loss = 1.6549042463302612


100%|██████████| 5/5 [00:00<00:00,  7.59it/s]


Epoch = 99, val_loss = 1.9820705652236938




 33%|███▎      | 5/15 [00:02<00:04,  2.13it/s]


KeyboardInterrupt: 

## 사용 가능한 loss 함수

In [None]:
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self):
        super(LabelSmoothingCrossEntropy, self).__init__()
    def forward(self, x, target, smoothing=0.1):
        confidence = 1. - smoothing
        logprobs = F.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = confidence * nll_loss + smoothing * smooth_loss
        return loss.mean()