In [1]:
import pickle
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision import transforms
import numpy as np
from einops import rearrange
import torch
from torchvision.models.video.resnet import VideoResNet, BasicBlock, R2Plus1dStem, Conv2Plus1D
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import gc

USE_CUDA = torch.cuda.is_available()
print(USE_CUDA)

device = torch.device('cuda:0' if USE_CUDA else 'cpu')
print('학습을 진행하는 기기:', device)

save_model_file_path = './save_model/{}_{}.{}'

True
학습을 진행하는 기기: cuda:0


In [2]:
with open('./processed_data/train/training_set.dat', "rb") as training_file:
    x1 = pickle.load(training_file)
with open('./processed_data/train/training_set(ElasticTransformation).dat', "rb") as training_file:
    x2 = pickle.load(training_file)
with open('./processed_data/train/training_set(GausianBlur).dat', "rb") as training_file:
    x3 = pickle.load(training_file)
with open('./processed_data/train/training_set(Salt).dat', "rb") as training_file:
    x4 = pickle.load(training_file)

In [3]:
x = x1+ x2+ x3+ x4
del x1, x2, x3, x4
gc.collect()

0

In [4]:
with open('./processed_data/test/test_set.dat', "rb") as training_file:
    y = pickle.load(training_file)

In [5]:
def one_hot_encode(labels, num_classes):
    one_hot_labels = np.zeros((len(labels), num_classes))
    for i in range(len(labels)):
        one_hot_labels[i, labels[i]] = 1
    return one_hot_labels

In [6]:
temp_x = one_hot_encode(x[1], 10)
temp_y = one_hot_encode(y[1], 10)

In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.hub import load_state_dict_from_url
import torchvision
from functools import partial
from collections import OrderedDict
import math

import os,inspect,sys

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
sys.path.insert(0,currentdir)

def convert_relu_to_swish(model):
        for child_name, child in model.named_children():
            if isinstance(child, nn.ReLU):
                setattr(model, child_name, nn.SiLU(True))
                # setattr(model, child_name, Swish())
            else:
                convert_relu_to_swish(child)

class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mul_(torch.sigmoid(x))

class r2plus1d_18(nn.Module):
    def __init__(self, pretrained=False, num_classes=10, dropout_p=0.5):
        super(r2plus1d_18, self).__init__()
        self.pretrained = pretrained
        self.num_classes = num_classes
        model = torchvision.models.video.r2plus1d_18(pretrained=self.pretrained)
        # delete the last fc layer
        modules = list(model.children())[:-1]
        # print(modules)
        self.r2plus1d_18 = nn.Sequential(*modules)
        convert_relu_to_swish(self.r2plus1d_18)
        self.fc1 = nn.Linear(model.fc.in_features, self.num_classes)
        self.dropout = nn.Dropout(dropout_p, inplace=True)
        # self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        out = self.r2plus1d_18(x)
        # print(out.shape)
        # Flatten the layer to fc
        out = out.flatten(1)
        out = self.dropout(out)
        out = self.fc1(out)
        # out = self.softmax(out)
        return out

class flow_r2plus1d_18(nn.Module):
    def __init__(self, pretrained=False, num_classes=10, dropout_p=0.5):
        super(flow_r2plus1d_18, self).__init__()
        self.pretrained = pretrained
        self.num_classes = num_classes
        model = torchvision.models.video.r2plus1d_18(pretrained=self.pretrained)

        model.stem[0] = nn.Conv3d(2, 45, kernel_size=(1, 7, 7),
                            stride=(1, 2, 2), padding=(0, 3, 3),
                            bias=False)

        # delete the last fc layer
        modules = list(model.children())[:-1]
        # print(modules)
        self.r2plus1d_18 = nn.Sequential(*modules)
        convert_relu_to_swish(self.r2plus1d_18)
        self.fc1 = nn.Linear(model.fc.in_features, self.num_classes)
        self.dropout = nn.Dropout(dropout_p, inplace=True)
    def forward(self, x):
        # print(x.size())
        out = self.r2plus1d_18(x)
        # print(out.shape)
        # Flatten the layer to fc
        out = out.flatten(1)
        out = self.dropout(out)
        out = self.fc1(out)

        return out

In [8]:
model = r2plus1d_18(num_classes = 10).to(device)



In [9]:
# from torchinfo import summary

# summary(model, input_size = (4,3,15,224,224), col_names = ['input_size','output_size','num_params'], verbose=1)

In [10]:
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [11]:
# train_transform = A.Compose([
#     A.GaussNoise(always_apply=False, p = 0.3, var_limit = (50.00, 100.00), per_channel = True, mean = 0.0),
#     A.RGBShift(always_apply=False, p = 0.3, r_shift_limit=(-10,10), g_shift_limit=(-10,10), b_shift_limit=(-10,10))])

# test_transform = A.Compose([
    
# ])

In [12]:
class SignLanGuageDataset(Dataset):
    def __init__(self,imagedata,tagdata):
        self.imagedata=imagedata
        self.tagdata=tagdata

    def __len__(self):
        return len(self.imagedata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_data=(self.imagedata[idx])
        image_data=torch.FloatTensor(image_data)
        label=self.tagdata[idx]
        label = torch.FloatTensor(label)
        return image_data,label

In [13]:
batchsz = 4
num_workerssz = 4
epochs = 120

In [14]:
train_dataset = SignLanGuageDataset(imagedata=x[0],tagdata=temp_x)
valid_dataset = SignLanGuageDataset(imagedata=y[0],tagdata=temp_y)

In [15]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batchsz, shuffle=False, num_workers=num_workerssz)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batchsz, shuffle=False, num_workers=num_workerssz)

In [16]:
with torch.cuda.device(0):
    
    for epoch in range(60):
        train_avg_loss = 0
        val_avg_loss = 0
        
        model.train()
        for data, target in tqdm(train_dataloader):
            data = rearrange(data, 'b d h w c -> b c d h w')
            data = data.to(device)
            
            target = target.to(device)
            optimizer.zero_grad()
            output = model(data)

            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            train_avg_loss += loss
        train_avg_loss = train_avg_loss/len(train_dataloader)
        print('Epoch = {}, train_loss = {}'.format(epoch+1, train_avg_loss))
        with torch.no_grad(): # valid
            model.eval()
            for data, target in tqdm(valid_dataloader):
                data = rearrange(data, 'b d h w c -> b c d h w')
                data = data.to(device)
                
                target = target.to(device)
                
                hypothesis = model(data)
                val_loss = criterion(hypothesis, target)
                val_avg_loss += val_loss
                
            val_avg_loss = val_avg_loss/len(valid_dataloader)

        torch.save({
            'epoch': epoch+1,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, save_model_file_path.format('model',(epoch+1),'pth'))    
        
        print('Epoch = {}, val_loss = {}'.format(epoch+1, val_avg_loss))
        print('\n')

100%|██████████| 15/15 [00:06<00:00,  2.37it/s]


Epoch = 1, train_loss = 2.432539701461792


100%|██████████| 5/5 [00:00<00:00,  8.66it/s]


Epoch = 1, val_loss = 6.186643123626709




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 2, train_loss = 2.297030448913574


100%|██████████| 5/5 [00:00<00:00,  8.52it/s]


Epoch = 2, val_loss = 2.683819532394409




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 3, train_loss = 2.262098550796509


100%|██████████| 5/5 [00:00<00:00,  8.39it/s]


Epoch = 3, val_loss = 2.282637357711792




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 4, train_loss = 2.1446237564086914


100%|██████████| 5/5 [00:00<00:00,  8.32it/s]


Epoch = 4, val_loss = 2.1256368160247803




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 5, train_loss = 1.948306918144226


100%|██████████| 5/5 [00:00<00:00,  8.19it/s]


Epoch = 5, val_loss = 2.0744597911834717




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 6, train_loss = 1.7244243621826172


100%|██████████| 5/5 [00:00<00:00,  8.54it/s]


Epoch = 6, val_loss = 6.6292595863342285




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 7, train_loss = 1.5010477304458618


100%|██████████| 5/5 [00:00<00:00,  8.39it/s]


Epoch = 7, val_loss = 3.1442625522613525




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 8, train_loss = 1.4638675451278687


100%|██████████| 5/5 [00:00<00:00,  8.22it/s]


Epoch = 8, val_loss = 2.345686674118042




100%|██████████| 15/15 [00:06<00:00,  2.49it/s]


Epoch = 9, train_loss = 1.2779802083969116


100%|██████████| 5/5 [00:00<00:00,  8.29it/s]


Epoch = 9, val_loss = 2.124040365219116




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 10, train_loss = 1.2055960893630981


100%|██████████| 5/5 [00:00<00:00,  8.63it/s]


Epoch = 10, val_loss = 8.43392562866211




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 11, train_loss = 0.9619966149330139


100%|██████████| 5/5 [00:00<00:00,  8.52it/s]


Epoch = 11, val_loss = 1.7002123594284058




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 12, train_loss = 0.9291332364082336


100%|██████████| 5/5 [00:00<00:00,  8.53it/s]


Epoch = 12, val_loss = 1.8751020431518555




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 13, train_loss = 0.792033314704895


100%|██████████| 5/5 [00:00<00:00,  8.61it/s]


Epoch = 13, val_loss = 2.4835827350616455




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 14, train_loss = 0.6511948704719543


100%|██████████| 5/5 [00:00<00:00,  8.37it/s]


Epoch = 14, val_loss = 1.4317898750305176




100%|██████████| 15/15 [00:06<00:00,  2.50it/s]


Epoch = 15, train_loss = 0.5783464312553406


100%|██████████| 5/5 [00:00<00:00,  8.17it/s]


Epoch = 15, val_loss = 1.5212472677230835




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 16, train_loss = 0.5268725156784058


100%|██████████| 5/5 [00:00<00:00,  8.53it/s]


Epoch = 16, val_loss = 1.8650015592575073




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 17, train_loss = 0.5098922252655029


100%|██████████| 5/5 [00:00<00:00,  8.52it/s]


Epoch = 17, val_loss = 2.5757243633270264




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 18, train_loss = 0.5426392555236816


100%|██████████| 5/5 [00:00<00:00,  8.51it/s]


Epoch = 18, val_loss = 1.5441057682037354




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 19, train_loss = 0.4227679371833801


100%|██████████| 5/5 [00:00<00:00,  8.15it/s]


Epoch = 19, val_loss = 3.734363555908203




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 20, train_loss = 0.4359978437423706


100%|██████████| 5/5 [00:00<00:00,  8.38it/s]


Epoch = 20, val_loss = 1.343965768814087




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 21, train_loss = 0.32944726943969727


100%|██████████| 5/5 [00:00<00:00,  8.53it/s]


Epoch = 21, val_loss = 1.7021633386611938




100%|██████████| 15/15 [00:06<00:00,  2.50it/s]


Epoch = 22, train_loss = 0.2835558354854584


100%|██████████| 5/5 [00:00<00:00,  8.46it/s]


Epoch = 22, val_loss = 1.7774044275283813




100%|██████████| 15/15 [00:06<00:00,  2.49it/s]


Epoch = 23, train_loss = 0.2864803075790405


100%|██████████| 5/5 [00:00<00:00,  8.25it/s]


Epoch = 23, val_loss = 2.209841251373291




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 24, train_loss = 0.2930375635623932


100%|██████████| 5/5 [00:00<00:00,  8.10it/s]


Epoch = 24, val_loss = 3.4598443508148193




100%|██████████| 15/15 [00:06<00:00,  2.50it/s]


Epoch = 25, train_loss = 0.2759188115596771


100%|██████████| 5/5 [00:00<00:00,  8.28it/s]


Epoch = 25, val_loss = 1.728043556213379




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 26, train_loss = 0.2290327101945877


100%|██████████| 5/5 [00:00<00:00,  8.74it/s]


Epoch = 26, val_loss = 1.7059965133666992




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 27, train_loss = 0.15530064702033997


100%|██████████| 5/5 [00:00<00:00,  8.49it/s]


Epoch = 27, val_loss = 2.233234167098999




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 28, train_loss = 0.08726318180561066


100%|██████████| 5/5 [00:00<00:00,  8.33it/s]


Epoch = 28, val_loss = 1.2486401796340942




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 29, train_loss = 0.06796178966760635


100%|██████████| 5/5 [00:00<00:00,  8.16it/s]


Epoch = 29, val_loss = 1.0452854633331299




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 30, train_loss = 0.05487263202667236


100%|██████████| 5/5 [00:00<00:00,  8.34it/s]


Epoch = 30, val_loss = 1.0498318672180176




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 31, train_loss = 0.04451386630535126


100%|██████████| 5/5 [00:00<00:00,  8.60it/s]


Epoch = 31, val_loss = 1.125334620475769




100%|██████████| 15/15 [00:06<00:00,  2.50it/s]


Epoch = 32, train_loss = 0.02910732664167881


100%|██████████| 5/5 [00:00<00:00,  8.51it/s]


Epoch = 32, val_loss = 1.098948359489441




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 33, train_loss = 0.02922758087515831


100%|██████████| 5/5 [00:00<00:00,  8.51it/s]


Epoch = 33, val_loss = 1.179972767829895




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 34, train_loss = 0.03195153549313545


100%|██████████| 5/5 [00:00<00:00,  8.61it/s]


Epoch = 34, val_loss = 1.1181044578552246




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 35, train_loss = 0.023805860430002213


100%|██████████| 5/5 [00:00<00:00,  8.51it/s]


Epoch = 35, val_loss = 1.1862115859985352




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 36, train_loss = 0.01820193976163864


100%|██████████| 5/5 [00:00<00:00,  8.17it/s]


Epoch = 36, val_loss = 1.128761649131775




100%|██████████| 15/15 [00:06<00:00,  2.49it/s]


Epoch = 37, train_loss = 0.016394924372434616


100%|██████████| 5/5 [00:00<00:00,  8.36it/s]


Epoch = 37, val_loss = 1.1799695491790771




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 38, train_loss = 0.015130282379686832


100%|██████████| 5/5 [00:00<00:00,  8.49it/s]


Epoch = 38, val_loss = 1.2163070440292358




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 39, train_loss = 0.013581734150648117


100%|██████████| 5/5 [00:00<00:00,  8.50it/s]


Epoch = 39, val_loss = 1.1621408462524414




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 40, train_loss = 0.014049182645976543


100%|██████████| 5/5 [00:00<00:00,  8.57it/s]


Epoch = 40, val_loss = 1.1532937288284302




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 41, train_loss = 0.014676941558718681


100%|██████████| 5/5 [00:00<00:00,  8.17it/s]


Epoch = 41, val_loss = 1.2517588138580322




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 42, train_loss = 0.010998254641890526


100%|██████████| 5/5 [00:00<00:00,  8.65it/s]


Epoch = 42, val_loss = 1.2727807760238647




100%|██████████| 15/15 [00:06<00:00,  2.50it/s]


Epoch = 43, train_loss = 0.008947291411459446


100%|██████████| 5/5 [00:00<00:00,  8.44it/s]


Epoch = 43, val_loss = 1.2886948585510254




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 44, train_loss = 0.009341876022517681


100%|██████████| 5/5 [00:00<00:00,  8.59it/s]


Epoch = 44, val_loss = 1.3076504468917847




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 45, train_loss = 0.009998375549912453


100%|██████████| 5/5 [00:00<00:00,  8.43it/s]


Epoch = 45, val_loss = 1.3146206140518188




100%|██████████| 15/15 [00:06<00:00,  2.50it/s]


Epoch = 46, train_loss = 0.00874705333262682


100%|██████████| 5/5 [00:00<00:00,  8.49it/s]


Epoch = 46, val_loss = 1.292404055595398




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 47, train_loss = 0.009360600262880325


100%|██████████| 5/5 [00:00<00:00,  8.53it/s]


Epoch = 47, val_loss = 1.2926453351974487




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 48, train_loss = 0.008051150478422642


100%|██████████| 5/5 [00:00<00:00,  8.72it/s]


Epoch = 48, val_loss = 1.269784927368164




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 49, train_loss = 0.0068405200727283955


100%|██████████| 5/5 [00:00<00:00,  8.58it/s]


Epoch = 49, val_loss = 1.4100618362426758




100%|██████████| 15/15 [00:06<00:00,  2.50it/s]


Epoch = 50, train_loss = 0.007001752033829689


100%|██████████| 5/5 [00:00<00:00,  8.59it/s]


Epoch = 50, val_loss = 1.3946852684020996




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 51, train_loss = 0.0062575917690992355


100%|██████████| 5/5 [00:00<00:00,  8.24it/s]


Epoch = 51, val_loss = 1.3680626153945923




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 52, train_loss = 0.006395444739609957


100%|██████████| 5/5 [00:00<00:00,  8.70it/s]


Epoch = 52, val_loss = 1.3586698770523071




100%|██████████| 15/15 [00:06<00:00,  2.50it/s]


Epoch = 53, train_loss = 0.005759020335972309


100%|██████████| 5/5 [00:00<00:00,  8.30it/s]


Epoch = 53, val_loss = 1.4385961294174194




100%|██████████| 15/15 [00:06<00:00,  2.49it/s]


Epoch = 54, train_loss = 0.004748825449496508


100%|██████████| 5/5 [00:00<00:00,  8.57it/s]


Epoch = 54, val_loss = 1.3582733869552612




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 55, train_loss = 0.0065528941340744495


100%|██████████| 5/5 [00:00<00:00,  8.20it/s]


Epoch = 55, val_loss = 1.402043342590332




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 56, train_loss = 0.005507077090442181


100%|██████████| 5/5 [00:00<00:00,  8.57it/s]


Epoch = 56, val_loss = 1.4072571992874146




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 57, train_loss = 0.0051718479953706264


100%|██████████| 5/5 [00:00<00:00,  8.57it/s]


Epoch = 57, val_loss = 1.5362600088119507




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 58, train_loss = 0.0038624899461865425


100%|██████████| 5/5 [00:00<00:00,  8.62it/s]


Epoch = 58, val_loss = 1.6033588647842407




100%|██████████| 15/15 [00:05<00:00,  2.51it/s]


Epoch = 59, train_loss = 0.005254645366221666


100%|██████████| 5/5 [00:00<00:00,  8.16it/s]


Epoch = 59, val_loss = 1.3906329870224




100%|██████████| 15/15 [00:05<00:00,  2.50it/s]


Epoch = 60, train_loss = 0.004456115886569023


100%|██████████| 5/5 [00:00<00:00,  8.57it/s]


Epoch = 60, val_loss = 1.4799762964248657




In [None]:
# checkpoint=torch.load("./save_model/model_33.pth", map_location=device)
# model.load_state_dict(checkpoint["model"])
# optimizer.load_state_dict(checkpoint["optimizer"])
# start_epoch = checkpoint['epoch']

In [None]:
with torch.no_grad():
    model.eval()
    for data, target in tqdm(valid_dataloader):
        data = rearrange(data, 'b d h w c -> b c d h w')
        data = data.to(device)
        
        target = target.to(device)
        
        hypothesis = model(data)
        
        # print('target = ', target,'\n', 'hypothesis = ', hypothesis, '\n\n\n\n')
        for i in range(batchsz):
            print('target = ', torch.argmax(target[i]),'\n', 'hypothesis = ', torch.argmax(hypothesis[i]), '\n\n\n\n')           

 20%|██        | 1/5 [00:00<00:00,  4.77it/s]

target =  tensor(9, device='cuda:0') 
 hypothesis =  tensor(9, device='cuda:0') 




target =  tensor(4, device='cuda:0') 
 hypothesis =  tensor(4, device='cuda:0') 




target =  tensor(0, device='cuda:0') 
 hypothesis =  tensor(6, device='cuda:0') 




target =  tensor(7, device='cuda:0') 
 hypothesis =  tensor(7, device='cuda:0') 




target =  

 60%|██████    | 3/5 [00:00<00:00,  7.19it/s]

tensor(5, device='cuda:0') 
 hypothesis =  tensor(5, device='cuda:0') 




target =  tensor(9, device='cuda:0') 
 hypothesis =  tensor(9, device='cuda:0') 




target =  tensor(6, device='cuda:0') 
 hypothesis =  tensor(6, device='cuda:0') 




target =  tensor(8, device='cuda:0') 
 hypothesis =  tensor(9, device='cuda:0') 




target =  tensor(4, device='cuda:0') 
 hypothesis =  tensor(2, device='cuda:0') 




target =  tensor(3, device='cuda:0') 
 hypothesis =  tensor(3, device='cuda:0') 




target =  tensor(7, device='cuda:0') 
 hypothesis =  tensor(6, device='cuda:0') 




target =  tensor(8, device='cuda:0') 
 hypothesis =  tensor(9, device='cuda:0') 




target =  

100%|██████████| 5/5 [00:00<00:00,  6.91it/s]

tensor(2, device='cuda:0') 
 hypothesis =  tensor(2, device='cuda:0') 




target =  tensor(1, device='cuda:0') 
 hypothesis =  tensor(1, device='cuda:0') 




target =  tensor(3, device='cuda:0') 
 hypothesis =  tensor(3, device='cuda:0') 




target =  tensor(6, device='cuda:0') 
 hypothesis =  tensor(6, device='cuda:0') 




target =  tensor(1, device='cuda:0') 
 hypothesis =  tensor(1, device='cuda:0') 




target =  tensor(0, device='cuda:0') 
 hypothesis =  tensor(0, device='cuda:0') 




target =  tensor(2, device='cuda:0') 
 hypothesis =  tensor(2, device='cuda:0') 




target =  tensor(5, device='cuda:0') 
 hypothesis =  tensor(5, device='cuda:0') 









In [None]:
# len(valid_dataset)

In [None]:
# valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=20, shuffle=False, num_workers=num_workerssz)

# with torch.no_grad(): # valid
#             model.eval()
#             for data, target in tqdm(valid_dataloader):
#                 data = rearrange(data, 'b d h w c -> b c d h w')
#                 data = data.to(device)
                
#                 target = target.to(device)
                
#                 hypothesis = model(data)

In [None]:
hypothesis.shape

torch.Size([4, 10])

## 사용 가능한 loss 함수

In [None]:
# class LabelSmoothingCrossEntropy(nn.Module):
#     def __init__(self):
#         super(LabelSmoothingCrossEntropy, self).__init__()
#     def forward(self, x, target, smoothing=0.1):
#         confidence = 1. - smoothing
#         logprobs = F.log_softmax(x, dim=-1)
#         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
#         nll_loss = nll_loss.squeeze(1)
#         smooth_loss = -logprobs.mean(dim=-1)
#         loss = confidence * nll_loss + smoothing * smooth_loss
#         return loss.mean()