In [1]:
import pickle
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision import transforms
import numpy as np
from einops import rearrange
import torch
from torchvision.models.video.resnet import VideoResNet, BasicBlock, R2Plus1dStem, Conv2Plus1D
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F

USE_CUDA = torch.cuda.is_available()
print(USE_CUDA)

device = torch.device('cuda:0' if USE_CUDA else 'cpu')
print('학습을 진행하는 기기:', device)

True
학습을 진행하는 기기: cuda:0


In [2]:
with open('training_set.dat', "rb") as training_file:
    x = pickle.load(training_file)

In [9]:
with open('test_set.dat', "rb") as training_file:
    y = pickle.load(training_file)

In [3]:
x[0].shape

(60, 30, 224, 224, 3)

In [4]:
x[1]

array([0, 2, 1, 7, 1, 7, 5, 3, 1, 2, 6, 5, 3, 1, 0, 7, 8, 1, 8, 4, 9, 0,
       5, 2, 6, 4, 6, 5, 0, 2, 5, 9, 4, 3, 7, 0, 9, 6, 6, 3, 3, 2, 7, 5,
       7, 6, 4, 8, 4, 3, 9, 8, 4, 9, 8, 9, 1, 8, 2, 0])

In [5]:
def one_hot_encode(labels, num_classes):
    one_hot_labels = np.zeros((len(labels), num_classes))
    for i in range(len(labels)):
        one_hot_labels[i, labels[i]] = 1
    return one_hot_labels

In [7]:
temp_x = one_hot_encode(x[1], 10)
# temp_y = one_hot_encode(y[1], 10)

In [8]:
temp_x

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 

In [8]:
# print('temp x = ', temp_x)
# print('temp y = ', temp_y)

In [11]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.hub import load_state_dict_from_url
import torchvision
from functools import partial
from collections import OrderedDict
import math

import os,inspect,sys

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
sys.path.insert(0,currentdir)

def convert_relu_to_swish(model):
        for child_name, child in model.named_children():
            if isinstance(child, nn.ReLU):
                setattr(model, child_name, nn.SiLU(True))
                # setattr(model, child_name, Swish())
            else:
                convert_relu_to_swish(child)

class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mul_(torch.sigmoid(x))

class r2plus1d_18(nn.Module):
    def __init__(self, pretrained=False, num_classes=10, dropout_p=0.5):
        super(r2plus1d_18, self).__init__()
        self.pretrained = pretrained
        self.num_classes = num_classes
        model = torchvision.models.video.r2plus1d_18(pretrained=self.pretrained)
        # delete the last fc layer
        modules = list(model.children())[:-1]
        # print(modules)
        self.r2plus1d_18 = nn.Sequential(*modules)
        convert_relu_to_swish(self.r2plus1d_18)
        self.fc1 = nn.Linear(model.fc.in_features, self.num_classes)
        self.dropout = nn.Dropout(dropout_p, inplace=True)
        # self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        out = self.r2plus1d_18(x)
        # print(out.shape)
        # Flatten the layer to fc
        out = out.flatten(1)
        out = self.dropout(out)
        out = self.fc1(out)
        # out = self.softmax(out)
        return out

class flow_r2plus1d_18(nn.Module):
    def __init__(self, pretrained=False, num_classes=10, dropout_p=0.5):
        super(flow_r2plus1d_18, self).__init__()
        self.pretrained = pretrained
        self.num_classes = num_classes
        model = torchvision.models.video.r2plus1d_18(pretrained=self.pretrained)

        model.stem[0] = nn.Conv3d(2, 45, kernel_size=(1, 7, 7),
                            stride=(1, 2, 2), padding=(0, 3, 3),
                            bias=False)

        # delete the last fc layer
        modules = list(model.children())[:-1]
        # print(modules)
        self.r2plus1d_18 = nn.Sequential(*modules)
        convert_relu_to_swish(self.r2plus1d_18)
        self.fc1 = nn.Linear(model.fc.in_features, self.num_classes)
        self.dropout = nn.Dropout(dropout_p, inplace=True)
    def forward(self, x):
        # print(x.size())
        out = self.r2plus1d_18(x)
        # print(out.shape)
        # Flatten the layer to fc
        out = out.flatten(1)
        out = self.dropout(out)
        out = self.fc1(out)

        return out

In [12]:
model = r2plus1d_18(num_classes = 10).to(device)



In [11]:
# from torchinfo import summary

# summary(model, input_size = (4,3,15,224,224), col_names = ['input_size','output_size','num_params'], verbose=1)

In [12]:
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [13]:
transform = transforms.Compose([
    # transforms.ToPILImage(),
    # transforms.Resize((224,224)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                      std=[0.229, 0.224, 0.225]),
])


In [14]:
class SignLanGuageDataset(Dataset):
    def __init__(self,imagedata,tagdata,transform):
        self.imagedata=imagedata
        self.tagdata=tagdata
        self.transform=transform

    def __len__(self):
        return len(self.imagedata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_data=(self.imagedata[idx])
        image_data=torch.FloatTensor(image_data)
        label=self.tagdata[idx]
        label = torch.FloatTensor(label)
        return image_data,label

In [15]:
batchsz = 4
num_workerssz = 4
epochs = 120

In [16]:
train_dataset = SignLanGuageDataset(imagedata=x[0],tagdata=temp_x,transform=transform)
valid_dataset = SignLanGuageDataset(imagedata=y[0],tagdata=temp_y,transform=transform)

In [17]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batchsz, shuffle=False, num_workers=num_workerssz)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batchsz, shuffle=False, num_workers=num_workerssz)

In [18]:
with torch.cuda.device(0):
    
    for epoch in range(100):
        train_avg_loss = 0
        val_avg_loss = 0
        
        model.train()
        for data, target in tqdm(train_dataloader):
            data = rearrange(data, 'b d h w c -> b c d h w')
            data = data.to(device)
            
            target = target.to(device)
            optimizer.zero_grad()
            output = model(data)

            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            train_avg_loss += loss
        train_avg_loss = train_avg_loss/len(train_dataloader)
        print('Epoch = {}, train_loss = {}'.format(epoch+1, train_avg_loss))
        with torch.no_grad(): # valid
            model.eval()
            for data, target in tqdm(valid_dataloader):
                data = rearrange(data, 'b d h w c -> b c d h w')
                data = data.to(device)
                
                target = target.to(device)
                
                hypothesis = model(data)
                val_loss = criterion(hypothesis, target)
                val_avg_loss += val_loss
                
            val_avg_loss = val_avg_loss/len(valid_dataloader)
        print('Epoch = {}, val_loss = {}'.format(epoch+1, val_avg_loss))
        print('\n')

  return F.conv3d(
100%|██████████| 15/15 [00:08<00:00,  1.79it/s]


Epoch = 1, train_loss = 2.4160335063934326


100%|██████████| 5/5 [00:00<00:00,  5.93it/s]


Epoch = 1, val_loss = 11.931220054626465




100%|██████████| 15/15 [00:06<00:00,  2.40it/s]


Epoch = 2, train_loss = 2.1489880084991455


100%|██████████| 5/5 [00:00<00:00,  6.07it/s]


Epoch = 2, val_loss = 2.6532907485961914




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 3, train_loss = 1.9059991836547852


100%|██████████| 5/5 [00:00<00:00,  5.86it/s]


Epoch = 3, val_loss = 2.3503129482269287




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 4, train_loss = 1.8548049926757812


100%|██████████| 5/5 [00:00<00:00,  5.71it/s]


Epoch = 4, val_loss = 3.3934338092803955




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 5, train_loss = 1.7383469343185425


100%|██████████| 5/5 [00:00<00:00,  5.83it/s]


Epoch = 5, val_loss = 2.906346321105957




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 6, train_loss = 1.5387890338897705


100%|██████████| 5/5 [00:00<00:00,  5.91it/s]


Epoch = 6, val_loss = 8.13010311126709




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 7, train_loss = 1.3452913761138916


100%|██████████| 5/5 [00:00<00:00,  5.45it/s]


Epoch = 7, val_loss = 1.7830876111984253




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 8, train_loss = 1.2601561546325684


100%|██████████| 5/5 [00:00<00:00,  6.05it/s]


Epoch = 8, val_loss = 1.9889116287231445




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 9, train_loss = 1.1651932001113892


100%|██████████| 5/5 [00:00<00:00,  5.93it/s]


Epoch = 9, val_loss = 4.57777738571167




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 10, train_loss = 1.2146214246749878


100%|██████████| 5/5 [00:00<00:00,  5.75it/s]


Epoch = 10, val_loss = 7.92144250869751




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 11, train_loss = 0.8644229173660278


100%|██████████| 5/5 [00:00<00:00,  5.57it/s]


Epoch = 11, val_loss = 2.088862180709839




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 12, train_loss = 0.6891402006149292


100%|██████████| 5/5 [00:00<00:00,  5.80it/s]


Epoch = 12, val_loss = 1.4790066480636597




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 13, train_loss = 0.7004570960998535


100%|██████████| 5/5 [00:00<00:00,  5.59it/s]


Epoch = 13, val_loss = 4.068557262420654




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 14, train_loss = 0.7041924595832825


100%|██████████| 5/5 [00:00<00:00,  5.71it/s]


Epoch = 14, val_loss = 2.4563748836517334




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 15, train_loss = 0.5988959074020386


100%|██████████| 5/5 [00:00<00:00,  5.49it/s]


Epoch = 15, val_loss = 1.2146397829055786




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 16, train_loss = 0.4517660140991211


100%|██████████| 5/5 [00:00<00:00,  5.86it/s]


Epoch = 16, val_loss = 0.9889117479324341




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 17, train_loss = 0.4522722065448761


100%|██████████| 5/5 [00:00<00:00,  5.90it/s]


Epoch = 17, val_loss = 5.947885990142822




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 18, train_loss = 0.3342623710632324


100%|██████████| 5/5 [00:00<00:00,  5.72it/s]


Epoch = 18, val_loss = 1.9094269275665283




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 19, train_loss = 0.3149493634700775


100%|██████████| 5/5 [00:00<00:00,  5.93it/s]


Epoch = 19, val_loss = 1.9564268589019775




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 20, train_loss = 0.24368487298488617


100%|██████████| 5/5 [00:00<00:00,  5.92it/s]


Epoch = 20, val_loss = 1.4787825345993042




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 21, train_loss = 0.14635637402534485


100%|██████████| 5/5 [00:00<00:00,  5.85it/s]


Epoch = 21, val_loss = 1.0110070705413818




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 22, train_loss = 0.08644018322229385


100%|██████████| 5/5 [00:00<00:00,  5.87it/s]


Epoch = 22, val_loss = 1.8408526182174683




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 23, train_loss = 0.055100876837968826


100%|██████████| 5/5 [00:00<00:00,  5.65it/s]


Epoch = 23, val_loss = 0.9419389963150024




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 24, train_loss = 0.0403081551194191


100%|██████████| 5/5 [00:00<00:00,  5.50it/s]


Epoch = 24, val_loss = 1.0021330118179321




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 25, train_loss = 0.029370686039328575


100%|██████████| 5/5 [00:00<00:00,  5.56it/s]


Epoch = 25, val_loss = 1.3245649337768555




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 26, train_loss = 0.026838155463337898


100%|██████████| 5/5 [00:00<00:00,  5.87it/s]


Epoch = 26, val_loss = 1.4944895505905151




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 27, train_loss = 0.020498525351285934


100%|██████████| 5/5 [00:00<00:00,  5.82it/s]


Epoch = 27, val_loss = 1.3431674242019653




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 28, train_loss = 0.01847081258893013


100%|██████████| 5/5 [00:00<00:00,  5.90it/s]


Epoch = 28, val_loss = 1.301395297050476




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 29, train_loss = 0.01419414859265089


100%|██████████| 5/5 [00:00<00:00,  5.73it/s]


Epoch = 29, val_loss = 1.3374444246292114




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 30, train_loss = 0.015171532519161701


100%|██████████| 5/5 [00:00<00:00,  5.80it/s]


Epoch = 30, val_loss = 1.4498144388198853




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 31, train_loss = 0.014677413739264011


100%|██████████| 5/5 [00:00<00:00,  5.62it/s]


Epoch = 31, val_loss = 1.3756043910980225




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 32, train_loss = 0.015622925013303757


100%|██████████| 5/5 [00:00<00:00,  5.78it/s]


Epoch = 32, val_loss = 1.3041496276855469




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 33, train_loss = 0.013197997584939003


100%|██████████| 5/5 [00:00<00:00,  5.44it/s]


Epoch = 33, val_loss = 1.2639678716659546




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 34, train_loss = 0.011219281703233719


100%|██████████| 5/5 [00:00<00:00,  5.77it/s]


Epoch = 34, val_loss = 1.5406392812728882




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 35, train_loss = 0.012173979543149471


100%|██████████| 5/5 [00:00<00:00,  5.08it/s]


Epoch = 35, val_loss = 1.566688060760498




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 36, train_loss = 0.011245914734899998


100%|██████████| 5/5 [00:00<00:00,  5.70it/s]


Epoch = 36, val_loss = 1.4797176122665405




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 37, train_loss = 0.008221768774092197


100%|██████████| 5/5 [00:00<00:00,  5.50it/s]


Epoch = 37, val_loss = 1.460831880569458




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 38, train_loss = 0.008469808846712112


100%|██████████| 5/5 [00:00<00:00,  5.45it/s]


Epoch = 38, val_loss = 1.6732017993927002




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 39, train_loss = 0.007255190052092075


100%|██████████| 5/5 [00:00<00:00,  5.51it/s]


Epoch = 39, val_loss = 1.784623146057129




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 40, train_loss = 0.006318945903331041


100%|██████████| 5/5 [00:00<00:00,  5.63it/s]


Epoch = 40, val_loss = 1.6164028644561768




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 41, train_loss = 0.006014461629092693


100%|██████████| 5/5 [00:00<00:00,  5.73it/s]


Epoch = 41, val_loss = 1.2586919069290161




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 42, train_loss = 0.006716050207614899


100%|██████████| 5/5 [00:00<00:00,  5.79it/s]


Epoch = 42, val_loss = 1.5392858982086182




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 43, train_loss = 0.006695104297250509


100%|██████████| 5/5 [00:00<00:00,  5.94it/s]


Epoch = 43, val_loss = 1.2309151887893677




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 44, train_loss = 0.005531140603125095


100%|██████████| 5/5 [00:00<00:00,  5.69it/s]


Epoch = 44, val_loss = 1.6077953577041626




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 45, train_loss = 0.00633389875292778


100%|██████████| 5/5 [00:00<00:00,  5.68it/s]


Epoch = 45, val_loss = 1.5722216367721558




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 46, train_loss = 0.008595102466642857


100%|██████████| 5/5 [00:00<00:00,  5.83it/s]


Epoch = 46, val_loss = 1.331290602684021




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 47, train_loss = 0.021190989762544632


100%|██████████| 5/5 [00:00<00:00,  5.92it/s]


Epoch = 47, val_loss = 1.3663866519927979




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 48, train_loss = 0.07846489548683167


100%|██████████| 5/5 [00:00<00:00,  5.54it/s]


Epoch = 48, val_loss = 7.71707010269165




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 49, train_loss = 0.7091789841651917


100%|██████████| 5/5 [00:00<00:00,  5.55it/s]


Epoch = 49, val_loss = 98.00215148925781




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 50, train_loss = 1.1614453792572021


100%|██████████| 5/5 [00:00<00:00,  5.77it/s]


Epoch = 50, val_loss = 66.62312316894531




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 51, train_loss = 0.7940191626548767


100%|██████████| 5/5 [00:00<00:00,  6.02it/s]


Epoch = 51, val_loss = 4.979474067687988




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 52, train_loss = 0.36579447984695435


100%|██████████| 5/5 [00:00<00:00,  5.87it/s]


Epoch = 52, val_loss = 1.0668420791625977




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 53, train_loss = 0.17888137698173523


100%|██████████| 5/5 [00:00<00:00,  5.92it/s]


Epoch = 53, val_loss = 1.156010389328003




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 54, train_loss = 0.14433351159095764


100%|██████████| 5/5 [00:00<00:00,  5.86it/s]


Epoch = 54, val_loss = 0.8743433356285095




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 55, train_loss = 0.10675950348377228


100%|██████████| 5/5 [00:00<00:00,  5.86it/s]


Epoch = 55, val_loss = 0.809332013130188




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 56, train_loss = 0.08238110691308975


100%|██████████| 5/5 [00:00<00:00,  5.82it/s]


Epoch = 56, val_loss = 0.7332069873809814




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 57, train_loss = 0.04753894731402397


100%|██████████| 5/5 [00:00<00:00,  5.79it/s]


Epoch = 57, val_loss = 0.6186617612838745




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 58, train_loss = 0.04448530077934265


100%|██████████| 5/5 [00:00<00:00,  5.60it/s]


Epoch = 58, val_loss = 1.0983402729034424




100%|██████████| 15/15 [00:06<00:00,  2.34it/s]


Epoch = 59, train_loss = 0.027877112850546837


100%|██████████| 5/5 [00:00<00:00,  5.55it/s]


Epoch = 59, val_loss = 0.9325658082962036




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 60, train_loss = 0.014953302219510078


100%|██████████| 5/5 [00:00<00:00,  5.64it/s]


Epoch = 60, val_loss = 0.9012727737426758




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 61, train_loss = 0.016875093802809715


100%|██████████| 5/5 [00:00<00:00,  5.12it/s]


Epoch = 61, val_loss = 0.7975407242774963




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 62, train_loss = 0.014530204236507416


100%|██████████| 5/5 [00:00<00:00,  5.44it/s]


Epoch = 62, val_loss = 0.829734742641449




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 63, train_loss = 0.013680838979780674


100%|██████████| 5/5 [00:00<00:00,  5.47it/s]


Epoch = 63, val_loss = 0.8728739023208618




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 64, train_loss = 0.011543575674295425


100%|██████████| 5/5 [00:00<00:00,  5.63it/s]


Epoch = 64, val_loss = 1.079216718673706




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 65, train_loss = 0.011808395385742188


100%|██████████| 5/5 [00:00<00:00,  5.51it/s]


Epoch = 65, val_loss = 1.0048657655715942




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 66, train_loss = 0.008440357632935047


100%|██████████| 5/5 [00:00<00:00,  5.39it/s]


Epoch = 66, val_loss = 0.9138802886009216




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 67, train_loss = 0.00909644179046154


100%|██████████| 5/5 [00:00<00:00,  5.66it/s]


Epoch = 67, val_loss = 0.9085180163383484




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 68, train_loss = 0.007109988015145063


100%|██████████| 5/5 [00:00<00:00,  5.99it/s]


Epoch = 68, val_loss = 0.9834467768669128




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 69, train_loss = 0.0069776661694049835


100%|██████████| 5/5 [00:00<00:00,  5.67it/s]


Epoch = 69, val_loss = 1.0071998834609985




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 70, train_loss = 0.009139245375990868


100%|██████████| 5/5 [00:00<00:00,  5.69it/s]


Epoch = 70, val_loss = 0.9333338737487793




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 71, train_loss = 0.006980872247368097


100%|██████████| 5/5 [00:00<00:00,  5.97it/s]


Epoch = 71, val_loss = 0.9169666171073914




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 72, train_loss = 0.00693886075168848


100%|██████████| 5/5 [00:00<00:00,  5.66it/s]


Epoch = 72, val_loss = 1.0433452129364014




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 73, train_loss = 0.0069983950816094875


100%|██████████| 5/5 [00:00<00:00,  5.71it/s]


Epoch = 73, val_loss = 1.0166430473327637




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 74, train_loss = 0.005408319178968668


100%|██████████| 5/5 [00:00<00:00,  6.15it/s]


Epoch = 74, val_loss = 1.0407402515411377




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 75, train_loss = 0.00512304762378335


100%|██████████| 5/5 [00:00<00:00,  5.66it/s]


Epoch = 75, val_loss = 1.0628902912139893




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 76, train_loss = 0.005486120469868183


100%|██████████| 5/5 [00:00<00:00,  5.94it/s]


Epoch = 76, val_loss = 1.0272200107574463




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 77, train_loss = 0.005730454809963703


100%|██████████| 5/5 [00:00<00:00,  5.77it/s]


Epoch = 77, val_loss = 1.0460776090621948




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 78, train_loss = 0.004747290164232254


100%|██████████| 5/5 [00:00<00:00,  5.83it/s]


Epoch = 78, val_loss = 1.0987138748168945




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 79, train_loss = 0.004972358699887991


100%|██████████| 5/5 [00:00<00:00,  5.70it/s]


Epoch = 79, val_loss = 1.089491844177246




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 80, train_loss = 0.005316173192113638


100%|██████████| 5/5 [00:00<00:00,  5.66it/s]


Epoch = 80, val_loss = 1.099668264389038




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 81, train_loss = 0.004877304658293724


100%|██████████| 5/5 [00:00<00:00,  5.78it/s]


Epoch = 81, val_loss = 1.1698641777038574




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 82, train_loss = 0.004840915556997061


100%|██████████| 5/5 [00:00<00:00,  5.35it/s]


Epoch = 82, val_loss = 1.130767822265625




100%|██████████| 15/15 [00:06<00:00,  2.34it/s]


Epoch = 83, train_loss = 0.0035626261960715055


100%|██████████| 5/5 [00:00<00:00,  5.60it/s]


Epoch = 83, val_loss = 1.1249744892120361




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 84, train_loss = 0.003480886109173298


100%|██████████| 5/5 [00:00<00:00,  5.45it/s]


Epoch = 84, val_loss = 1.0914009809494019




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 85, train_loss = 0.0034269141033291817


100%|██████████| 5/5 [00:00<00:00,  5.33it/s]


Epoch = 85, val_loss = 1.1582238674163818




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 86, train_loss = 0.0033161393366754055


100%|██████████| 5/5 [00:00<00:00,  5.77it/s]


Epoch = 86, val_loss = 1.11613130569458




100%|██████████| 15/15 [00:06<00:00,  2.35it/s]


Epoch = 87, train_loss = 0.003637814661487937


100%|██████████| 5/5 [00:00<00:00,  5.81it/s]


Epoch = 87, val_loss = 1.1135791540145874




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 88, train_loss = 0.004024864174425602


100%|██████████| 5/5 [00:00<00:00,  5.56it/s]


Epoch = 88, val_loss = 1.2771726846694946




100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 89, train_loss = 0.0027834575157612562


100%|██████████| 5/5 [00:00<00:00,  6.04it/s]


Epoch = 89, val_loss = 1.224900722503662




100%|██████████| 15/15 [00:06<00:00,  2.39it/s]


Epoch = 90, train_loss = 0.003970739431679249


100%|██████████| 5/5 [00:00<00:00,  5.99it/s]


Epoch = 90, val_loss = 1.0570605993270874




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 91, train_loss = 0.002399043645709753


100%|██████████| 5/5 [00:00<00:00,  5.96it/s]


Epoch = 91, val_loss = 1.1016093492507935




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 92, train_loss = 0.002714322181418538


100%|██████████| 5/5 [00:00<00:00,  6.02it/s]


Epoch = 92, val_loss = 1.1343985795974731




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 93, train_loss = 0.0020295821595937014


100%|██████████| 5/5 [00:00<00:00,  5.61it/s]


Epoch = 93, val_loss = 1.1831663846969604




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 94, train_loss = 0.0024442258290946484


100%|██████████| 5/5 [00:00<00:00,  5.80it/s]


Epoch = 94, val_loss = 1.1116749048233032




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 95, train_loss = 0.0027783445548266172


100%|██████████| 5/5 [00:00<00:00,  5.91it/s]


Epoch = 95, val_loss = 1.1124904155731201




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 96, train_loss = 0.002611847361549735


100%|██████████| 5/5 [00:00<00:00,  5.71it/s]


Epoch = 96, val_loss = 1.1231354475021362




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 97, train_loss = 0.0024531143717467785


100%|██████████| 5/5 [00:00<00:00,  5.66it/s]


Epoch = 97, val_loss = 1.1013611555099487




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 98, train_loss = 0.002182362135499716


100%|██████████| 5/5 [00:00<00:00,  5.63it/s]


Epoch = 98, val_loss = 1.084539771080017




100%|██████████| 15/15 [00:06<00:00,  2.36it/s]


Epoch = 99, train_loss = 0.002274280646815896


100%|██████████| 5/5 [00:00<00:00,  6.18it/s]


Epoch = 99, val_loss = 1.2012542486190796




100%|██████████| 15/15 [00:06<00:00,  2.38it/s]


Epoch = 100, train_loss = 0.0018757020588964224


100%|██████████| 5/5 [00:00<00:00,  5.80it/s]

Epoch = 100, val_loss = 1.2451039552688599







In [19]:
torch.save(model.state_dict(), 'model_weights.pth')

In [20]:
valid_dataset[0][0].shape

torch.Size([30, 224, 224, 3])

In [21]:
len(valid_dataset)

20

In [27]:
with torch.no_grad():
    model.eval()
    for data, target in tqdm(valid_dataloader):
        data = rearrange(data, 'b d h w c -> b c d h w')
        data = data.to(device)
        
        target = target.to(device)
        
        hypothesis = model(data)
        
        # print('target = ', target,'\n', 'hypothesis = ', hypothesis, '\n\n\n\n')
        for i in range(batchsz):
            print('target = ', torch.argmax(target[i]),'\n', 'hypothesis = ', torch.argmax(hypothesis[i]), '\n\n\n\n')           

 20%|██        | 1/5 [00:00<00:01,  2.31it/s]

target =  tensor(9, device='cuda:0') 
 hypothesis =  tensor(8, device='cuda:0') 




target =  tensor(4, device='cuda:0') 
 hypothesis =  tensor(4, device='cuda:0') 




target =  tensor(0, device='cuda:0') 
 hypothesis =  tensor(0, device='cuda:0') 




target =  tensor(7, device='cuda:0') 
 hypothesis =  tensor(6, device='cuda:0') 




target =  

 60%|██████    | 3/5 [00:00<00:00,  5.17it/s]

tensor(5, device='cuda:0') 
 hypothesis =  tensor(5, device='cuda:0') 




target =  tensor(9, device='cuda:0') 
 hypothesis =  tensor(8, device='cuda:0') 




target =  tensor(6, device='cuda:0') 
 hypothesis =  tensor(6, device='cuda:0') 




target =  tensor(8, device='cuda:0') 
 hypothesis =  tensor(9, device='cuda:0') 




target =  tensor(4, device='cuda:0') 
 hypothesis =  tensor(2, device='cuda:0') 




target =  tensor(3, device='cuda:0') 
 hypothesis =  tensor(3, device='cuda:0') 




target =  tensor(7, device='cuda:0') 
 hypothesis =  tensor(6, device='cuda:0') 




target =  tensor(8, device='cuda:0') 
 hypothesis =  tensor(8, device='cuda:0') 




target =  

100%|██████████| 5/5 [00:00<00:00,  6.60it/s]

tensor(2, device='cuda:0') 
 hypothesis =  tensor(2, device='cuda:0') 




target =  tensor(1, device='cuda:0') 
 hypothesis =  tensor(7, device='cuda:0') 




target =  tensor(3, device='cuda:0') 
 hypothesis =  tensor(6, device='cuda:0') 




target =  tensor(6, device='cuda:0') 
 hypothesis =  tensor(6, device='cuda:0') 




target =  tensor(1, device='cuda:0') 
 hypothesis =  tensor(1, device='cuda:0') 




target =  tensor(0, device='cuda:0') 
 hypothesis =  tensor(6, device='cuda:0') 




target =  tensor(2, device='cuda:0') 
 hypothesis =  tensor(2, device='cuda:0') 




target =  tensor(5, device='cuda:0') 
 hypothesis =  tensor(5, device='cuda:0') 






100%|██████████| 5/5 [00:01<00:00,  4.89it/s]


In [23]:
# len(valid_dataset)

In [24]:
# valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=20, shuffle=False, num_workers=num_workerssz)

# with torch.no_grad(): # valid
#             model.eval()
#             for data, target in tqdm(valid_dataloader):
#                 data = rearrange(data, 'b d h w c -> b c d h w')
#                 data = data.to(device)
                
#                 target = target.to(device)
                
#                 hypothesis = model(data)

In [25]:
hypothesis.shape

torch.Size([4, 10])

## 사용 가능한 loss 함수

In [26]:
# class LabelSmoothingCrossEntropy(nn.Module):
#     def __init__(self):
#         super(LabelSmoothingCrossEntropy, self).__init__()
#     def forward(self, x, target, smoothing=0.1):
#         confidence = 1. - smoothing
#         logprobs = F.log_softmax(x, dim=-1)
#         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
#         nll_loss = nll_loss.squeeze(1)
#         smooth_loss = -logprobs.mean(dim=-1)
#         loss = confidence * nll_loss + smoothing * smooth_loss
#         return loss.mean()