# Subtask 1

## Dataset

In [1]:
import numpy as np 
import copy #to copy data
import json #data가 json 형식으로 제공됨

import torch
from torch.utils.data import Dataset #store sample and label

from transformers import * 

In [2]:
#데이터셋 확인
data_path="hi/e_train.json"
dialog_data = json.load(open(data_path, 'r', encoding='utf-8')) #데이터 경로 지정받아서 불러오기

In [3]:
# special_tokens: 토큰화 과정에서 문장이 분할되지 않도록 spial token을 지정해줌
SPECIAL_TOKENS = ['[BOS]', '[EOS]', '[speaker1]', '[speaker2]', '[IMG]', '[TAG]', '[PAD]']
SPECIAL_TOKENS_DICT = {'bos_token':'[BOS]', 'eos_token':'[EOS]', 'additional_special_tokens':['[speaker1]', '[speaker2]', '[IMG]', '[TAG]'], 'pad_token':'[PAD]'}
SPECIAL_TOKENS_DICT
#BOS: Begin of Sentence
#EOS: End Of Sentence

{'bos_token': '[BOS]',
 'eos_token': '[EOS]',
 'additional_special_tokens': ['[speaker1]', '[speaker2]', '[IMG]', '[TAG]'],
 'pad_token': '[PAD]'}

In [4]:
#객체를 받아 토큰화시키는 함수 : str 일때, dictionary일때, 그 외일때 나눠서 
def tokenize(obj, tokenizer):
    if isinstance(obj, str): #obj가 srt 형식인지 아닌지
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        return dict((n, tokenize(o, tokenizer)) for n,o in obj.items())
    return list(tokenize(o, tokenizer) for o in obj) 
# return=> token의 id 또는 List of id

In [5]:
# history와 answer data로 분리
def get_data(tokenizer, data_path, meme_feature_path):
    dialog_data = json.load(open(data_path, 'r', encoding='utf-8')) #데이터 경로 지정받아서 불러오기
    dialog_list = [] 
    for idx in dialog_data.keys(): #dialog_data.keys()=> 대화 각각 하나 하나 
        dialog = dialog_data[idx] #몇번째 대화
        history = [] 
        
        for i in range(len(dialog)): 
            if 'txt' in dialog[i].keys(): #각 대화에서 Txt가 있다면 토큰화해서 저장
                dialog[i]['txt'] = tokenize(dialog[i]['txt'], tokenizer) 
            if i == 0: 
                history.append(dialog[i]) 
                continue 
            pair = {'history': copy.deepcopy(history), 'answer': copy.deepcopy(dialog[i])}  #copy.deepcopy: 내부객체들까지 모두 새롭게 copy
            #history: 주고받은 대화들, answer: 마지막에 한 발화 
            #짝 지어서 새로 저장
            dialog_list.append(pair) 
            history.append(dialog[i]) 
        # break 
    id2feature = json.load(open(meme_feature_path, 'r', encoding='utf-8'))  #meme의 ID
    return dialog_list, id2feature 
#dialog_list: history와 answer로 나누어진 대화 데이터 

In [6]:
# build input type from data: 데이터로에서 input 유형 구축  
def build_input_from_segments(history, tokenizer, id2feature, answer=None): 
    bos, eos, speaker1, speaker2, img, tag =tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1]) 
    #SPECIAL_TOKENS[:-1] = ['[BOS]', '[EOS]', '[speaker1]', '[speaker2]', '[IMG]', '[TAG]']
    #PAD는 벡터 길이를 맞춰주기 위함이므로 굳이 id로 변환해주지 않아도 됨


    #history data는 현재 img_id와 emotion_id가 둘다 포함되어 있음: text와 meme이 함께 있는 상태 
    #-> 이를 분리해서 따로 넣어주자
    history_txt = [] #history 중 text 부분
    history_img = [] #history 중 image 부분
    labels = [] #bos eos 토큰 말고 추가적으로 나오는 토큰들의 value: IMG, TAG의 특정값ㅠ
    token_type_ids = [] #토큰화 유형 아이디
    
    # to avoid out of length, cut the partial sequence
    ans_len = 4 
    
    #대답이 존재하는데 그 대답에 text가 존재할 경우
    if answer is not None and 'txt' in answer.keys():
        ans_len += len(answer['txt']) #answer 중 텍스트 부분의 길이 누적합
    
    for i in range(len(history)-1, -1, -1): 
        # 길이 재기
        cur_len = 4 #초기값 4로 설정해둠(current length)
        if 'txt' in history[i].keys():
            cur_len += len(history[i]['txt']) #history 중 텍스트 부분의 길이 누적합
        if len(token_type_ids) + ans_len + cur_len > 500: 
            break #특정 길이를 넘어가면 멈추기
            
        #speaker id: 1,2만 존재하도록   
        if history[i]['speaker_id'] == '[speaker1]': 
            speaker_id = speaker1 
        else:
            speaker_id = speaker2 
         
        #history에 img있는 경우: history img로 저장
        if 'img_id' in history[i].keys(): 
            history_img = [id2feature[history[i]['img_id']]] + history_img
            token_type_ids = [img] + token_type_ids 
            labels = [-100] + labels 
        
        #history에 txt있는 경우: history text로 저장
        if 'txt' in history[i].keys(): 
            content = [bos] + history[i]['txt'] + [eos] #문장과 시작과 끝 포함
            history_txt = content + history_txt 
            token_type_ids = [speaker_id] * len(content) + token_type_ids 
            labels = [-100] * len(content) + labels 
        else: 
            content = [bos] + [eos] 
            history_txt = content + history_txt 
            token_type_ids = [speaker_id] * len(content) + token_type_ids 
            labels = [-100] * len(content) + labels 
    
        history_txt = [speaker_id] + history_txt 
        token_type_ids = [speaker_id] + token_type_ids 
        labels = [-100] + labels 
    
    #대답이 존재하는 경우 
    if answer is not None: 
        #발화자 지정
        if answer['speaker_id'] == '[speaker1]': 
            speaker_id = speaker1 
        else:
            speaker_id = speaker2 
    
        history_txt += [speaker_id] 
        token_type_ids += [speaker_id]
        
        if 'txt' in answer.keys(): 
            #첫시작하는 토큰 + 대답에서 text 부분 + 끝내는 토큰
            content = [bos] + answer['txt'] + [eos] 
            history_txt += content 
            token_type_ids += [speaker_id] * len(content) 
            labels += content 
        else: 
            content = [bos] + [eos] 
            history_txt += content 
            token_type_ids += [speaker_id] * len(content) 
            labels += content 
    
        labels += [-100, -100] 
        history_txt += [tag] 
        token_type_ids += [img] 
        
        #meme_flag: 0 또는 1로 구성 
        if 'img_id' in answer.keys(): 
            history_img += [id2feature[answer['img_id']]] 
            meme_flag = [1]
        else:
            history_img += [[0.0]*512] 
            meme_flag = [0] 
    return history_txt, history_img, token_type_ids, labels[1:], meme_flag #output

In [7]:
class MODDataset(Dataset): 
    def __init__(self, dialogs, id2feature, tokenizer): 
        self.dialogs = dialogs 
        self.id2feature = id2feature 
        self.tokenizer = tokenizer 
        #필요한 변수 선언: dialogs, id2feature, tokenizer=> 대화(text), 밈, 토큰화하기 위한 tokenizer
    
    def __len__(self):
        return len(self.dialogs) 
        #데이터셋의 샘플 개수 반환: len 함수 
    
    def __getitem__(self, index): 
      #__getitem__: 주어진 idxex 에 해당하는 샘플을 데이터셋에서 불러오고 반환해줌
    
      #get data 함수에서 history 와 answer로 나눴었음 
        his = copy.deepcopy(self.dialogs[index]['history'])  #history data
        ans = copy.deepcopy(self.dialogs[index]['answer'])  #answer data
        # print(his)
        # print(ans)
        
        #tensor 형태로 변환: 각 형태에 맞춰
        history_txt, histroy_img, token_type_ids, labels, meme_flag = build_input_from_segments(his, self.tokenizer, self.id2feature, ans) 
        history_txt = torch.LongTensor(history_txt) #64 bit integer
        histroy_img = torch.from_numpy(np.array(histroy_img)).float() 
        #numpy로 변경 후 float 형식으로 변경 
        token_type_ids = torch.Tensor(token_type_ids).long()
        labels = torch.Tensor(labels).long()
        meme_flag = torch.Tensor(meme_flag).long() 
        #torch.Tensor:tensor 값 변경하더라도 numpy 변화 x
        #torch.from_numpy:array의 dtype을 상속받고 tensor와 메모리 버퍼를 공유하기 때문에 tensor의 값이 변경되면 Numpy array값이 변경
        return history_txt, histroy_img, token_type_ids, labels, meme_flag #Output

In [8]:
if __name__ == '__main__': 
    data_path = 'hi/e_train.json' 
    meme_feature_path = 'hi/id2feature.json'
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', do_lower_case=True)
    #공백도 취급: 공백여부에 따라도 다르게 토큰화됨
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    dialog_list, id2feature = get_data(tokenizer, data_path, meme_feature_path) 
    dataset = MODDataset(dialog_list, id2feature, tokenizer) 
    history_txt, history_img, token_type_ids, labels, meme_flag = dataset[0]
    print(tokenizer.convert_ids_to_tokens(history_txt))
    print(history_img.size())
    print(tokenizer.convert_ids_to_tokens(token_type_ids))
    print(tokenizer.convert_ids_to_tokens(labels))

['[speaker1]', '[BOS]', 'H', 'aha', '.', 'Ġ.', 'Ġ.', 'ĠDon', "'t", 'Ġknow', 'Ġwho', 'Ġi', 'Ġam', '[EOS]', '[speaker2]', '[BOS]', 'O', 'uy', 'ang', '?', '[EOS]', '[TAG]']
torch.Size([2, 512])
['[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[speaker1]', '[IMG]', '[speaker2]', '[speaker2]', '[speaker2]', '[speaker2]', '[speaker2]', '[speaker2]', '[speaker2]', '[IMG]']
[None, None, None, None, None, None, None, None, None, None, None, None, None, None, '[BOS]', 'O', 'uy', 'ang', '?', '[EOS]', None, None]


In [9]:
dataset[0]

(tensor([50259, 50257,    39, 12236,    13,   764,   764,  2094,   470,   760,
           508,  1312,   716, 50258, 50260, 50257,    46,  4669,   648,    30,
         50258, 50262]),
 tensor([[ 0.2078, -0.0766,  0.4716,  ...,  0.0949,  0.1990, -0.2000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),
 tensor([50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
         50259, 50259, 50259, 50259, 50261, 50260, 50260, 50260, 50260, 50260,
         50260, 50260, 50261]),
 tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100, 50257,    46,  4669,   648,    30, 50258,
          -100,  -100]),
 tensor([0]))

Construct a GPT-2 tokenizer. Based on byte-level Byte-Pair-Encoding.
Byte Pair Encoding (BPE)
gpt-2는 Byte Pair Encoding를 거친 토큰을 입력 단위로 사용합니다.

BPE는 서브워드를 분리하는 알고리즘으로, 빈도수에 따라 문자를 병합하여 서브워드를 구성합니다. 단어를 문자(char) 단위로 쪼갠 뒤, 가장 빈도수가 높은 쌍을 하나로 통합하는 과정을 반복하여 토큰 딕셔너리를 만듭니다.

앞으로 단어, 토큰이라고 불리는 것은 모두 BPE token을 의미합니다.



## Model

In [10]:
import torch 
from torch import nn 
from torch.nn import CrossEntropyLoss, MSELoss
import os 
from transformers import * 

class 형태의 모델은 항상 nn.Module 을 상속받아야 하며, super(모델명, self).__init__() 을 통해 nn.Module.__init__() 을 실행시키는 코드가 필요합니다.
forward() 는 모델이 학습데이터를 입력받아서 forward propagation을 진행시키는 함수이고, 반드시 forward 라는 이름의 함수이어야 합니다.


In [72]:
class MemeDialoGPT(GPT2PreTrainedModel): 
    def __init__(self, config): 
        #super: 자식클래스에서 부모클래스의 내용을 사용하고 싶은 경우, super().부모클래스내용, 여기서 부모클래스는 MODDataset
        #MODDataset에서 지정한 변수들 그대로 사용
        super(MemeDialoGPT, self).__init__(config) 
        self.transformer = GPT2Model(config) #GPT2 모델을 편의상 transformer로 정의 
        
        
        ## for text hidden state
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 
        # n_embd: Dimensionality of the embeddings and hidden states
        # vocab_size:the number of different tokens that can be represented by the inputs_ids passed when calling GPT2Model or TFGPT2Model.

        ## for image hidden state # E JARIE CNN GANUNG HAL DUT
        #self.img_ff = nn.Linear(512, config.n_embd)  
        #input: 512개의 픽셀 사용하는 이미지/ output: 임베딩 결과와 hidden state의 차원
        #self.img_ff = nn.Conv2d(2, 512, kernel_size=1)
        #self.img_ff = nn.Conv2d(2, 512, kernel_size=1)
        self.img_ff = nn.Conv2d(512, config.n_embd, kernel_size=1)
        #self.img_ff2 = nn.Linear()
        #, kernel_size=(3, 3), stride=(2, 2), padding=(2, 2)
        #self.img_inverse_ff = nn.Conv2d(config.n_embd, 512, kernel_size=1)
        #self.img_inverse_ff = nn.Conv2d(512, 512, kernel_size=1)
        #, kernel_size=(3, 3), stride=(2, 2), padding=(2, 2)
        self.img_inverse_ff = nn.Linear(config.n_embd, 512)
        #config.n_embd = 768
        
        #img_ff와 input, output 반대
        #GPT2모델은 masked-self attention 활용하는 자기 회귀 모델임: inverse 함수 필요
        
        # predict the meme usage 
        #사용 여부에 따라 0,1
        self.lm_flag = nn.Linear(config.n_embd, 2)

    def tie_weights(self):  #Tie the weights between the input embeddings and the output embeddings: 각각의 임베딩 결과물 벡터에는 가중치가 부여되어 있는데 input의 w와 output의 w 묶음.  
        self._tie_or_clone_weights(self.lm_head, self.transformer.wte) 
    
    def forward(self, input_embeds, token_type_ids, labels=None, img_feature=None, meme_flag=None): 
      #모델이 학습데이터를 입력받아서 forward propagation을 진행시키는 함수
      #input_embeds: input_construct 함수의 return 값
        #GPT2 모델에 임베딩한 input들을 넣음
        transformer_outputs = self.transformer(inputs_embeds=input_embeds, token_type_ids=token_type_ids) 
        
        #output값 중 일부를 hidden state로 지정해서 정보를 저장하고 넘김
        #Sequence of hidden-states at the output of the last layer of the model
        #text랑 image가 담는 정보가 다를 수밖에 없으므로 따로 저장
        hidden_states = transformer_outputs[0]
        txt_hidden_states, img_hidden_states = hidden_states[:-1, :], hidden_states[-1, :].unsqueeze(0) 
        print(txt_hidden_states.size())
        print(img_hidden_states.size())
        lm_logits = self.lm_head(txt_hidden_states) #hidden state를 어휘에 대한 확률 분포로 변환
        img_regs = self.img_inverse_ff(img_hidden_states) #픽셀마다 명암에 대한 밀도값 추정: 픽셀값 추정/ regression
        print(lm_logits.size())
        print(img_regs.size())
        
        outputs = (lm_logits,) + (img_regs, ) 
        
        #섭테 1,2,3 모두 활용 가능 
        if labels is not None:  #answer 존재
            txt_loss_fct = CrossEntropyLoss(ignore_index=-100) 
            loss = txt_loss_fct(lm_logits, labels)  
            #loss=CrossEntropyLoss(lm_logits, labels,ignore_index=-100) 

            #CrossEntropyLoss: 분류 문제처럼 확률값으로 나올때. meme이 있는 경우 쓸지 안쓸지 확률로 고려해야하기 때문에 사용해줌


            if meme_flag is not None: 
              #meme_flag가 존재:  answer에 meme 있는지 여부에 따라 0,1 나눴었음
                mf_logits = self.lm_flag(img_hidden_states)  #0,1 : binary classification
                mf_loss_fct = CrossEntropyLoss()
                mf_flag_loss = mf_loss_fct(mf_logits, meme_flag) 
#                 mf_flag_loss = CrossEntropyLoss(mf_logits, meme_flag) 
                loss += mf_flag_loss #위에 만든 loss에 더해짐
                outputs = (mf_logits,) + outputs

            if img_feature[0][0] != 0.:  #적잘한 Meme을 잘 썼는지 평가: 픽셀값 예측통해
                img_loss_fct = MSELoss() 
                loss += img_loss_fct(img_regs, img_feature) 
#                 meme 없는 경우: 
#                 loss += MSELoss(img_regs, img_feature) 
                #img_feature: def img_feature_read
            outputs = (loss,) + outputs 
        return outputs   

손실 함수란 신경망이 학습할 수 있도록 해주는 지표이다. 머신러닝 모델의 출력값과 사용자가 원하는 출력값의 차이, 즉 오차를 말한다. 이 손실 함수 값이 최소화되도록 하는 가중치와 편향을 찾는 것이 바로 학습이다. 일반적인 손실 함수로 평균 제곱 오차나 교차 엔트로피 오차를 사용

## train

In [127]:
# #각각 다른 파일에 있을 경우에는 class를 다시 불러와줘야함
# from model import MemeDialoGPT 
# from dataset import MODDataset, get_data 
# from utils import accuracy_compute, AverageMeter, meme_classify_accuracy

In [128]:
SPECIAL_TOKENS = ['[BOS]', '[EOS]', '[speaker1]', '[speaker2]', '[IMG]', '[TAG]', '[PAD]']
SPECIAL_TOKENS_DICT = {'bos_token':'[BOS]', 'eos_token':'[EOS]', 'additional_special_tokens':['[speaker1]', '[speaker2]', '[IMG]', '[TAG]'], 'pad_token':'[PAD]'}

# data parameters
train_data_path = 'hi/e_train.json'
val_data_path = 'hi/e_validation.json' 
feature_path = 'hi/id2feature.json'


# model parameters
use_cuda = torch.cuda.is_available() 
device = torch.device('cuda' if use_cuda else 'cpu') 
model_path = 'hi'
gpt_path = 'gpt2' 
ckpt_usage = False
lr = 6e-5
epochs = 1
gradient_accumulation_steps = 1
print_freq = 1

In [129]:
 # concatenate the input 
def input_construct(history_txt_embs, history_img_embs, token_type_ids, tokenizer): 
        ## in train fuction:이미지와 텍스트 따로 임베딩해줬음
        #history_txt_embs = model.transformer.wte(history_txt) 
        #history_img_embs = model.img_ff(history_img) 
    bos, eos, speaker1, speaker2, img, tag = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1]) 
    emb_length = token_type_ids.size(-1) 
    emb_dim = history_txt_embs.size(-1)
    img_num = history_img_embs.size(0) 

    #0으로만 이루어진 텐서하나 만들고 임베딩할 길이와 차원
    input_embs = torch.zeros((emb_length, emb_dim)).to(device)

    #임베딩 결과를 정리
    #순서가 현재 meme
    txt_idx = 0 
    img_idx = 0 
    left_idx  = 0 
    right_idx = 0 
    while right_idx < emb_length: 
        #if right_idx == emb_length-1 and token_type_ids[right_idx] == img: 
        #    break 

        #right index가 embeding length까지 1씩 추가되면서 반복문 돌던 중 meme이 나오면
        #거기까지 text, 그 이후 meme
        if right_idx < emb_length-1 and token_type_ids[right_idx] == img:
            txt_length = right_idx - left_idx 
            input_embs[left_idx:right_idx, :] = history_txt_embs[txt_idx:txt_idx+txt_length, :] 
            txt_idx += txt_length 
            input_embs[right_idx,:] = history_img_embs[img_idx, :] 
            img_idx += 1
            left_idx = right_idx + 1 
        right_idx += 1
    txt_length = right_idx - left_idx 
    if txt_length > 0: 
        input_embs[left_idx:right_idx, :] = history_txt_embs[txt_idx:, :]
    # img_feature = history_img_embs[img_idx,:] 
    return input_embs

In [130]:
#meme 이미지 불러옴
def img_feature_read(feature_path): 
    with open(feature_path, 'r', encoding='utf-8') as f: 
        id2feature_dict = json.load(f) 
    img_features = [] 
    for id in id2feature_dict.keys(): #key: 이미지 번호
        img_features.append(id2feature_dict[id]) 
        #제공해준 meme 모음 파일과 비교하면서 대화 속 meme list 저장
    img_features = np.array(img_features) 
    img_features = torch.from_numpy(img_features).float().to(device)
    #tensor 형태로 저장
    print(img_features)
    return img_features

In [131]:
def meme_retrieval_compute(cur_img_feature, target_img_feature, cat_img_features): 
    # (1, 512)
    #현재 이미지 피쳐와 타겟 이미지 피쳐 간 거리 계산
    cur_dist = torch.dist(cur_img_feature, target_img_feature, p=2)
    # print(cat_img_features.size())
    #307개의 이미지
    cur_img_list = cur_img_feature.repeat(307,1) #img_regs
    #오차의 제곱합의 루트: rmse
    total_dist = torch.sqrt(torch.sum((cur_img_list - cat_img_features)**2, dim=1))
    # print(total_dist) 
    sorted_total, _ = torch.sort(total_dist) 
    # print(sorted_total) 
    return torch.gt(sorted_total[90],cur_dist)

In [132]:
def train(model, tokenizer, optimizer, dataset, epoch): 
    model.train() 
    cat_img_features = img_feature_read(feature_path) 
    avg_loss = AverageMeter() 
    avg_acc = AverageMeter() 
    iteration = 1
    meme_correct_num = 0 
    meme_total_num = 0

    for instance in dataset: 
        history_txt, history_img, token_type_ids, labels, meme_flag = instance 
        history_txt, history_img, token_type_ids, labels, meme_flag = history_txt.to(device).squeeze(0), history_img.to(device).squeeze(0), \
                                                                        token_type_ids.to(device).squeeze(0), labels.to(device).squeeze(0), meme_flag.to(device).squeeze(0)   
        history_txt_embs = model.transformer.wte(history_txt) 
        #print(history_txt_embs.size()
        print(history_img.size())
        #history_img1 = history_img.transpose(0, 1)
        #print(history_img1.size())
        history_img2 = history_img.unsqueeze(-1)
        print(history_img2.size())
        history_img3 = history_img2.unsqueeze(-1)
        print(history_img3.size())
        history_img_embs = model.img_ff(history_img3)
        print(history_img_embs.size())
        #history_img_embs = history_img_embs.view(history_img_embs.shape[0], -1)
        history_img_embs = history_img_embs.view(2, -1)
        print(history_img_embs.size()) 
        #print(token_type_ids) 
        #print(history_txt)
        input_embs = input_construct(history_txt_embs, history_img_embs, token_type_ids, tokenizer) 
        input_embs = input_embs.to(device) 
        img_feature = history_img[-1, :].unsqueeze(0)
        # print(input_embs.size()) 
        # print(img_feature.size()) 
        loss, mf_logits, lm_logits, cur_img_feature = model(input_embs, token_type_ids, labels, img_feature, meme_flag) 
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 

        if iteration % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            #meme이 있는 경우 적절하게 사용되었지만 파악함
        if img_feature[0][0] != 0.: 
            if meme_retrieval_compute(cur_img_feature, img_feature, cat_img_features):
                meme_correct_num += 1 
            meme_total_num += 1 
        acc = accuracy_compute(lm_logits, labels, 5)
        #잘 분류되었는지 체크
        #acc = meme_classify_accuracy(mf_logits, meme_flag).item()
        avg_acc.update(acc)
        
        avg_loss.update(loss.item())
        
        # print status 
#        if iteration % print_freq == 0:
#             print('Epoch:[{0}][{1}/{2}]\t'
#             'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
#             'Classify Acc {acc.val:.3f} ({acc.avg:.3f})\t'
#             'Meme Acc {mac:.3f}'.format(epoch, iteration, len(dataset),loss=avg_loss, acc=avg_acc, mac=float(meme_correct_num/meme_total_num)))
        if iteration % print_freq == 0:
            print('Epoch:[{0}][{1}/{2}]\t'
            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
            'Classify Acc {acc.val:.3f} ({acc.avg:.3f})\t'.format(epoch, iteration, len(dataset),loss=avg_loss, acc=avg_acc))
        print(lm_logits)
        print(labels)
        print(acc)
        print(lm_logits.size())
        print(labels.size())
        
        
        iteration += 1
        break
    return avg_loss.avg
        # print(loss)
        # break 

In [133]:
def validate(model, tokenizer, dataset, epoch): 
    
    model.eval() 
    avg_loss = AverageMeter() 
    avg_acc = AverageMeter() 
    avg_bleu = AverageMeter() 
    iteration = 1 
    cat_img_features = img_feature_read(feature_path) 
    meme_correct_num = 0 
    meme_total_num = 0

    with torch.no_grad(): 
        for instance in dataset: 
            history_txt, history_img, token_type_ids, labels, meme_flag = instance 
            history_txt, history_img, token_type_ids, labels, meme_flag  = history_txt.to(device).squeeze(0), history_img.to(device).squeeze(0), \
                                                                            token_type_ids.to(device).squeeze(0), labels.to(device).squeeze(0), meme_flag.to(device).squeeze(0) 
            history_txt_embs = model.transformer.wte(history_txt)
            #history_img1 = history_img.transpose(0, 1)
            #print(history_img1.size())
            history_img2 = history_img.unsqueeze(-1)
            print(history_img2.size())
            history_img3 = history_img2.unsqueeze(-1)
            print(history_img3.size())
            history_img_embs = model.img_ff(history_img3)
            print(history_img_embs.size())
        #history_img_embs = history_img_embs.view(history_img_embs.shape[0], -1)
            history_img_embs = history_img_embs.view(2, -1)
            #print(history_img_embs.size())
            #history_img_embs = model.img_ff2(history_img_embs)
            #print(history_img_embs.size())
            #history_img_embs = model.img_ff_lin
            
            
            #history_img_embs = model.img_ff(history_img) 
            
            input_embs = input_construct(history_txt_embs, history_img_embs, token_type_ids, tokenizer) 
            input_embs = input_embs.to(device) 
            if input_embs.size(-2) > 450:
                continue
            img_feature = history_img[-1, :].unsqueeze(0) 
            loss, mf_logits, lm_logits, cur_img_feature = model(input_embs, token_type_ids, labels, img_feature, meme_flag) 
            # compare cur_img_feature is among topk with img_feature 
            # print(cur_img_feature.size())   (1, 512) 
            if img_feature[0][0] != 0.: 
                if meme_retrieval_compute(cur_img_feature, img_feature, cat_img_features):
                    meme_correct_num += 1 
                meme_total_num += 1 
            acc = accuracy_compute(lm_logits, labels, k=5) 
            #acc = meme_classify_accuracy(mf_logits, meme_flag).item()
            avg_acc.update(acc) 
            avg_loss.update(loss.item()) 
            if iteration % print_freq == 0:
                print('Epoch:[{0}][{1}/{2}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Classify Acc {acc.val:.3f} ({acc.avg:.3f})\t'.format(epoch, iteration, len(dataset),loss=avg_loss, acc=avg_acc))

#                 print('Epoch:[{0}][{1}/{2}]\t'
#                 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
#                 'Acc {acc.val:.3f} ({acc.avg:.3f})\t'
#                 'Meme Acc {mac:.3f}'.format(epoch, iteration, len(dataset),loss=avg_loss, acc=avg_acc, mac=float(meme_correct_num/meme_total_num))) 
        
            iteration += 1 
            break

In [134]:
#메인 함수 
def main(): 
    
    # model initialize  #모델  초기화
    if ckpt_usage == True: 
        ckpt_path = 'gpt2' 
         

    else:
        tokenizer = GPT2Tokenizer.from_pretrained(gpt_path, do_lower_case=True) #gpt 모델 사용 
        model = MemeDialoGPT.from_pretrained(gpt_path)
        tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT) 
        model.resize_token_embeddings(len(tokenizer))
    model = model.to(device) 
    optimizer = AdamW(model.parameters(), lr=lr)

    # data read : train, validation
    #데이터 불러와서 로드까지!
    train_dialogs, id2feature = get_data(tokenizer, train_data_path, feature_path) #여기서 왜 자꾸 오류? : 불러오는데 문제 있?
    val_dialogs, _ = get_data(tokenizer, val_data_path, feature_path) 
    #print(len(train_dialogs))
    train_dataset = MODDataset(train_dialogs, id2feature, tokenizer) 
    val_dataset = MODDataset(val_dialogs, id2feature, tokenizer) 
    # print(len(train_dataset))
    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=1, num_workers=4, pin_memory=True) 
    
    for epoch in range(epochs): 
        
        # one epoch's training
        val_loss = train(model=model, tokenizer=tokenizer, optimizer=optimizer, dataset=train_loader, epoch=epoch) 
        
        # one epoch's validation 
        validate(model=model, tokenizer=tokenizer, dataset=val_loader, epoch=epoch)
        
        #break 트라이1
        # save checkpoint 
        torch.save({'model':model.state_dict(), 'optimizer': optimizer.state_dict()},\
            '%s/epoch_%d_loss_%.3f'%(model_path, epoch, val_loss))
        model.config.to_json_file(os.path.join(model_path, 'config.json'))
        tokenizer.save_vocabulary(model_path)

In [135]:
from transformers import GPT2Tokenizer, GPT2Model

## utils

In [136]:
import torch 

# calculate the accuracy of response  
def accuracy_compute(lm_logits, targets, k=5):
    _, idx = torch.topk(lm_logits, k, 1)
    correct = idx.eq(targets.view(-1,1).expand_as(idx))
    correct_total = correct.view(-1).float().sum().item()
    nums = targets.view(-1).detach().cpu().numpy()
    length = 0
    for num in nums:
        if num != -100:
            length += 1
            
    return correct_total / float(length)

#지금 correct_total이 [x, y] 식으로 출력되고 length

In [137]:
def meme_classify_accuracy(mf_logits, meme_flag):
    prediction = torch.argmax(mf_logits, 1) 
    return (prediction == meme_flag).sum()

In [138]:
# class for evaluation metric 
class AverageMeter(object):
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [139]:
from torch.utils.data import Dataset, DataLoader

In [140]:
#함수 시행
if __name__ == '__main__': 
    main()

Some weights of MemeDialoGPT were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.7.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.1.attn.masked_bias', 'h.11.attn.masked_bias', 'img_inverse_ff.bias', 'lm_flag.bias', 'h.6.attn.masked_bias', 'h.9.attn.masked_bias', 'lm_head.weight', 'h.3.attn.masked_bias', 'img_inverse_ff.weight', 'h.8.attn.masked_bias', 'lm_flag.weight', 'img_ff.bias', 'h.10.attn.masked_bias', 'img_ff.weight', 'h.0.attn.masked_bias', 'h.2.attn.masked_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tensor([[-0.2045, -0.3605,  0.1353,  ...,  0.2115, -0.2773, -0.4641],
        [-0.0759, -0.0538,  0.1329,  ...,  0.6213, -0.1812, -0.1301],
        [-0.1741, -0.1418,  0.1408,  ...,  0.5004, -0.1981, -0.4364],
        ...,
        [-0.1725, -0.2610, -0.2571,  ...,  0.7686, -0.2020, -0.2577],
        [-0.0091, -0.2522, -0.0359,  ...,  0.3179, -0.1381, -0.0518],
        [-0.3150, -0.2858,  0.2325,  ...,  0.2793,  0.1454, -0.2166]],
       device='cuda:0')
torch.Size([2, 512])
torch.Size([2, 512, 1])
torch.Size([2, 512, 1, 1])
torch.Size([2, 768, 1, 1])
torch.Size([2, 768])
torch.Size([22, 768])
torch.Size([1, 768])
torch.Size([22, 50264])
torch.Size([1, 512])
Epoch:[0][1/859685]	Loss 53.6130 (53.6130)	Classify Acc 0.333 (0.333)	
tensor([[-1.2757e+01, -1.2444e+01, -1.4971e+01,  ...,  4.5236e-01,
          1.2282e-02, -3.4585e-01],
        [-9.3402e+01, -9.2599e+01, -1.0109e+02,  ...,  1.9279e+00,
         -2.5267e-02, -1.1982e-01],
        [-8.4746e+01, -8.4042e+01, -8.6261e+01,  ...,  1.

In [141]:
# 위 칸에 출력되는게 train score고,
# 아래칸에 출력되는게 validation score같음...

# 근데 지금보면 데이터수가 859685 : 12666 으로 거의 68:1 인거 같음.
# validation data갯수가 작네?