In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import copy
import pickle
import numpy as np
import time
import sys

OSError: [WinError 126] 지정된 모듈을 찾을 수 없습니다. Error loading "C:\Users\USER\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\torch\lib\shm.dll" or one of its dependencies.

In [10]:
from ipynb.fs.full. my_tk import My_tokenizer
from ipynb.fs.full. config import Config
from ipynb.fs.full. utils_data import DLoader
from ipynb.fs.full. utils_func import *
from ipynb.fs.full. model_ino import *

In [11]:
class Trainer:
    """Set up data loaders, encoder/decoder models and (in train mode) optimizers for captioning.

    Args:
        config: project ``Config`` supplying paths and hyper-parameters (``base_path``,
            ``model_path``, ``loss_data_path``, ``batch_size``, ``epochs``, ``enc_lr``,
            ``dec_lr``, ``img_size``, ``max_len``, ...).
        device: device the encoder and decoder are moved to.
        mode: one of ``'train'``, ``'test'`` or ``'inference'``.
        continuous: truthy to resume a previous run; the pickled loss history at
            ``config.loss_data_path`` is loaded and, in train mode, the model and optimizer
            states are restored from ``config.model_path``.

    Optional config attributes (read with ``getattr`` so existing configs keep working):
        max_pairs: how many (image, caption) pairs to keep from the caption file;
            defaults to 74, ``None`` keeps every pair.
        num_workers: DataLoader worker processes; defaults to 4.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """

    # Modes accepted by __init__; any other value used to yield a half-configured trainer.
    _MODES = ('train', 'test', 'inference')

    def __init__(self, config:Config, device:torch.device, mode:str, continuous:int):
        if mode not in self._MODES:
            raise ValueError('mode must be one of {}, got {!r}'.format(self._MODES, mode))

        self.config = config
        self.device = device
        self.mode = mode
        self.continuous = continuous
        self.dataloaders = {}

        self.model = {}
        self.loss_data = {}

        # If continuing a previous run, restore its loss/score history.
        # NOTE(review): pickle.load can execute arbitrary code; only load files you produced.
        if self.continuous:
            with open(self.config.loss_data_path, 'rb') as f:
                self.loss_data = pickle.load(f)

        # path params
        self.base_path = self.config.base_path
        self.model_path = self.config.model_path

        # train params
        self.batch_size = self.config.batch_size
        self.epochs = self.config.epochs
        self.enc_lr = self.config.enc_lr
        self.dec_lr = self.config.dec_lr

        # model params
        self.img_size = self.config.img_size
        self.max_len = self.config.max_len

        # for reproducibility
        torch.manual_seed(999)

        # ImageNet mean/std, because the pre-trained ResNet101 encoder was trained on ImageNet
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.trans = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])

        # make dataset
        self.img_folder = self.base_path + 'data/Images/'
        self.caption_file = self.base_path + 'data/captions.txt'
        self.all_pairs = collect_all_pairs(self.caption_file)

        # For quick experiments only the first `max_pairs` pairs are used
        # (default 74; set config.max_pairs = None to use every pair).
        max_pairs = getattr(self.config, 'max_pairs', 74)
        self.all_pairs = self.all_pairs[:max_pairs]
        # NOTE(review): the meaning of the second argument (10) is defined by make_dataset_ids — confirm.
        self.trainset_id, self.valset_id = make_dataset_ids(len(self.all_pairs), 10)
        # The tokenizer only receives the training-split ids.
        self.tokenizer = My_tokenizer(self.config, self.all_pairs, self.trainset_id)

        # Worker processes can be troublesome in some environments (e.g. notebooks on
        # Windows), so the count is configurable; the default keeps the previous value.
        num_workers = getattr(self.config, 'num_workers', 4)

        # train set (only needed when training)
        if self.mode == 'train':
            self.trainset = DLoader(self.img_folder, self.all_pairs, self.trans, self.trainset_id, self.tokenizer, self.max_len)
            self.dataloaders['train'] = DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True, num_workers=num_workers)

        # validation set, exposed under the 'test' key
        self.valset = DLoader(self.img_folder, self.all_pairs, self.trans, self.valset_id, self.tokenizer, self.max_len)
        self.dataloaders['test'] = DataLoader(self.valset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        # model, loss (padding tokens are ignored by the loss)
        self.encoder = Encoder(self.config).to(self.device)
        self.decoder = Decoder(self.config, self.tokenizer).to(self.device)
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)

        if self.mode == 'train':
            # Only encoder parameters with requires_grad=True are optimized.
            self.enc_optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, self.encoder.parameters()), lr=self.enc_lr)
            self.dec_optimizer = optim.Adam(self.decoder.parameters(), lr=self.dec_lr)

            # resume training: restore model and optimizer states
            if self.continuous:
                self._load_checkpoint(load_optimizers=True)
        else:
            # test / inference: un-normalized transform, plus trained weights in eval mode
            self.trans4attn = transforms.Compose([
                transforms.Resize((252, 252)),
                transforms.ToTensor()])
            self._load_checkpoint(load_optimizers=False)
            self.encoder.eval()
            self.decoder.eval()

    def _load_checkpoint(self, load_optimizers: bool) -> None:
        """Restore encoder/decoder weights (and optionally optimizer states) from ``self.model_path``."""
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        check_point = torch.load(self.model_path, map_location=self.device)
        self.encoder.load_state_dict(check_point['model']['encoder'])
        self.decoder.load_state_dict(check_point['model']['decoder'])
        if load_optimizers:
            self.enc_optimizer.load_state_dict(check_point['optimizer']['encoder'])
            self.dec_optimizer.load_state_dict(check_point['optimizer']['decoder'])
        # free the checkpoint copy before training/inference starts
        del check_point
        torch.cuda.empty_cache()

In [12]:
def _forward_batch(Trainer, img, cap):
    """Teacher-forced decode of one batch; return ``(loss, top-k accuracy, decoder outputs)``.

    At every step ``j < Trainer.max_len`` the ground-truth token ``cap[:, j]`` is fed to
    the decoder, and the output of step ``j`` is scored against ``cap[:, j + 1]``.
    """
    enc_output, hidden = Trainer.encoder(img)

    decoder_all_output, decoder_all_score = [], []
    for j in range(Trainer.max_len):
        trg_word = cap[:, j].unsqueeze(1)
        dec_output, hidden, score = Trainer.decoder(trg_word, hidden, enc_output)
        decoder_all_output.append(dec_output)
        if Trainer.config.is_attn:  # attention scores are only collected for attention decoders
            decoder_all_score.append(score)
    decoder_all_output = torch.cat(decoder_all_output, dim=1)

    # step j predicts token j+1: drop the last prediction and the first target token
    pred = decoder_all_output[:, :-1, :]
    trg = cap[:, 1:]
    loss = Trainer.criterion(pred.reshape(-1, decoder_all_output.size(-1)), trg.reshape(-1))

    if Trainer.config.is_attn:
        decoder_all_score = torch.cat(decoder_all_score, dim=2)
        # penalise attention whose sum over dim 2 deviates from 1, weighted by regularization_lambda
        loss = loss + Trainer.config.regularization_lambda * ((1. - torch.sum(decoder_all_score, dim=2)) ** 2).mean()

    acc = topk_accuracy(pred, trg, Trainer.config.topk, Trainer.tokenizer.eos_token_id)
    return loss, acc, decoder_all_output


def model_train(Trainer):
    """Train ``Trainer.encoder``/``Trainer.decoder`` with early stopping on validation BLEU-4.

    Each epoch runs a ``'train'`` phase (forward, backward, optimizer steps) and a
    ``'test'`` phase (forward only). The test phase records the loss plus BLEU-2/4,
    NIST-2/4 and top-k accuracy. Whenever BLEU-4 improves, the weights are snapshotted
    and written via ``save_checkpoint`` to ``Trainer.model_path``. Training stops once
    ``config.early_stop_criterion`` consecutive epochs pass without improvement.

    Args:
        Trainer: a ``Trainer`` built with ``mode='train'`` (needs the ``'train'``
            dataloader and both optimizers). When ``Trainer.continuous`` is truthy,
            histories are resumed from ``Trainer.loss_data``.

    Returns:
        ``(model, loss_data)``. ``model`` is ``{'encoder': ..., 'decoder': ...}``, with the
        modules holding the best weights. ``loss_data`` holds the best epoch, best BLEU-4
        and the loss/score histories, using the keys read back when resuming.
    """
    early_stop = 0

    # resume histories if continuing a previous run
    if Trainer.continuous:
        best_val_bleu = Trainer.loss_data['best_val_bleu']
        train_loss_history = Trainer.loss_data['train_loss_history']
        val_loss_history = Trainer.loss_data['val_loss_history']
        val_score_history = Trainer.loss_data['val_score_history']
        best_epoch_info = Trainer.loss_data['best_epoch']
    else:
        best_val_bleu = 0
        train_loss_history = []
        val_loss_history = []
        val_score_history = {'bleu2': [], 'bleu4': [], 'nist2': [], 'nist4': [], 'topk_acc': []}
        best_epoch_info = 0

    # Start from the current weights so a run in which BLEU-4 never improves (or with
    # zero epochs) still has valid "best" weights and a best epoch to report.
    best_enc_wts = copy.deepcopy(Trainer.encoder.state_dict())
    best_dec_wts = copy.deepcopy(Trainer.decoder.state_dict())
    best_epoch = best_epoch_info

    for epoch in range(Trainer.epochs):
        print('*'*10, 'epoch start', '*'*10)
        start = time.time()
        print(epoch + 1, '/', Trainer.epochs)
        print('-'*10)

        for phase in ['train', 'test']:
            print('Phase: {}'.format(phase))
            is_train = phase == 'train'
            Trainer.encoder.train(is_train)  # train(False) is equivalent to eval()
            Trainer.decoder.train(is_train)

            total_loss, total_acc = 0, 0
            all_val_trg, all_val_output = [], []
            loader = Trainer.dataloaders[phase]

            for i, (img, cap, _) in enumerate(loader):
                batch_size = img.size(0)
                img, cap = img.to(Trainer.device), cap.to(Trainer.device)
                if is_train:
                    Trainer.enc_optimizer.zero_grad()
                    Trainer.dec_optimizer.zero_grad()

                # gradients are tracked only in the train phase
                with torch.set_grad_enabled(is_train):
                    loss, acc, decoder_all_output = _forward_batch(Trainer, img, cap)

                    if is_train:
                        loss.backward()
                        Trainer.enc_optimizer.step()
                        Trainer.dec_optimizer.step()
                    else:
                        all_val_trg.append(cap.detach().cpu())
                        all_val_output.append(decoder_all_output.detach().cpu())

                # accumulate sample-weighted sums; divided by dataset size below
                total_loss += loss.item()*batch_size
                total_acc += acc * batch_size

                if i % 100 == 0:
                    print('Epoch {}: {}/{} step loss: {}, top-{} acc: {}'.format(epoch+1, i, len(loader), loss.item(), Trainer.config.topk, acc))

            epoch_loss = total_loss/len(loader.dataset)
            epoch_acc = total_acc/len(loader.dataset)
            print(f'{phase} loss: {epoch_loss}, top-{Trainer.config.topk} acc: {epoch_acc}\n')

            if is_train:
                train_loss_history.append(epoch_loss)
            else:
                val_loss_history.append(epoch_loss)

                # print examples from the last validation batch
                print_samples(cap, decoder_all_output, Trainer.tokenizer)

                # validation scores; bleu4 drives model selection below
                all_val_trg, all_val_output = tensor2list(all_val_trg, all_val_output, Trainer.tokenizer)
                val_score_history['bleu2'].append(cal_scores(all_val_trg, all_val_output, 'bleu', 2))
                val_score_history['bleu4'].append(cal_scores(all_val_trg, all_val_output, 'bleu', 4))
                val_score_history['nist2'].append(cal_scores(all_val_trg, all_val_output, 'nist', 2))
                val_score_history['nist4'].append(cal_scores(all_val_trg, all_val_output, 'nist', 4))
                val_score_history['topk_acc'].append(epoch_acc)
                print('bleu2: {}, bleu4: {}, nist2: {}, nist4: {}'.format(val_score_history['bleu2'][-1], val_score_history['bleu4'][-1], val_score_history['nist2'][-1], val_score_history['nist4'][-1]))

                # save best model; early_stop counts epochs since the last improvement
                early_stop += 1
                if best_val_bleu < val_score_history['bleu4'][-1]:
                    early_stop = 0
                    best_val_bleu = val_score_history['bleu4'][-1]
                    best_enc_wts = copy.deepcopy(Trainer.encoder.state_dict())
                    best_dec_wts = copy.deepcopy(Trainer.decoder.state_dict())
                    best_epoch = best_epoch_info + epoch + 1
                    save_checkpoint(Trainer.model_path, [Trainer.encoder, Trainer.decoder], [Trainer.enc_optimizer, Trainer.dec_optimizer])

        print('*'*10, 'epoch end', '*'*10)
        print("time: {} s\n".format(time.time() - start))
        print('\n'*2)

        # early stopping
        if early_stop == Trainer.config.early_stop_criterion:
            break

    print('best val bleu: {:.4f}, best epoch: {:d}\n'.format(best_val_bleu, best_epoch))

    # load_state_dict returns a missing/unexpected-keys report, not the module,
    # so load the best weights first and then store the modules themselves.
    Trainer.encoder.load_state_dict(best_enc_wts)
    Trainer.decoder.load_state_dict(best_dec_wts)
    Trainer.model = {'encoder': Trainer.encoder, 'decoder': Trainer.decoder}
    Trainer.loss_data = {'best_epoch': best_epoch, 'best_val_bleu': best_val_bleu, 'train_loss_history': train_loss_history, 'val_loss_history': val_loss_history, 'val_score_history': val_score_history}

    return Trainer.model, Trainer.loss_data