In [2]:
# Global configuration shared by every cell below
import torch

# Data Information


# Batch size returned by get_data_loader(train=True)
TRAINING_BATCH_SIZE = 64

# Batch size returned by get_data_loader(train=False)
TESTING_BATCH_SIZE = 1024

# Use the GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ENCODER_DIM = 512 # Determined by the last dimension of the encoded image features
# NOTE(review): the encoder's own shape comments list 2048 channels for the resnet
# branch — confirm this value matches the CNN backbone that is actually selected.

# Hidden size of the decoder RNN cell
DECODER_DIM = 512

# Projection size stored with checkpoints and intended for the "local" attention variant
ATTENTION_DIM = 256

# Step size handed to the Adam optimizer
LEARNING_RATE = 3e-4

# Size of the word embedding vectors
EMBED_DIM = 200


In [3]:
# Build the vocabulary (word dictionary)
import spacy

# Load spaCy's small English pipeline; only its tokenizer is used below
SPACY_OBJ = spacy.load("en_core_web_sm")

#Defining our dictionary for all the words contained in the provided captions
class MyVocab:
    """
    Vocabulary that maps caption tokens to integer indices and back.

    Four special tokens are always present and occupy indices 0-3:
        <PAD>: pads sentences in a batch to a uniform length.
        <SOS>: marks the start of a sentence.
        <EOS>: marks the end of a sentence.
        <UNK>: stands in for any word that is not part of the vocabulary.
    """
    def __init__(self):
        # Index -> token mapping, seeded with the special tokens
        self.index_to_tokens = {0:"<PAD>",1:"<SOS>",2:"<EOS>",3:"<UNK>"}

        # Token -> index mapping (inverse of the dictionary above)
        self.tokens_to_index = {token:index for index,token in self.index_to_tokens.items()}



    def __len__(self):
        """
        :return: int, the number of stored tokens (special tokens included)
        """
        return len(self.index_to_tokens)


    def build_vocab(self,sentence_list,min_count=1,max_count=None,max_features=None):
        """
        Count the words of all sentences, filter them, and add the survivors to the vocabulary.

        Words that are already in the vocabulary keep their index, so calling this
        method more than once never produces duplicate or colliding indices.

        :param sentence_list: iterable of sentences (strings)
        :param min_count: keep only words that appear at least this many times (None disables the filter)
        :param max_count: keep only words that appear at most this many times (None disables the filter)
        :param max_features: keep only this many of the most frequent words (None disables the filter)
        :return: None; updates tokens_to_index, index_to_tokens and frequency_counter in place
        """

        # Word frequencies of this call; also kept as an attribute for inspection
        self.frequency_counter = {}

        for sentence in sentence_list:
            for word in self.tokenize(sentence):
                self.frequency_counter[word] = self.frequency_counter.get(word,0)+1

        # Filtering
        if min_count is not None:
            self.frequency_counter = {word:value for word,value in self.frequency_counter.items() if value >= min_count}

        if max_count is not None:
            self.frequency_counter = {word:value for word,value in self.frequency_counter.items() if value <= max_count}

        if max_features is not None:
            most_frequent = sorted(self.frequency_counter.items(),key=lambda x:x[-1],reverse=True)
            self.frequency_counter = dict(most_frequent[:max_features])

        # Append new words only: re-assigning an existing word would free its old
        # index and make the next new word collide with an index already in use.
        for word in self.frequency_counter:
            if word not in self.tokens_to_index:
                self.tokens_to_index[word] = len(self.tokens_to_index)

        # Rebuild the inverse mapping
        self.index_to_tokens = {index:token for token,index in self.tokens_to_index.items()}



    def sentence_to_index(self,sentence,max_len = 20):
        """
        Tokenize a sentence and convert every token to its vocabulary index.

        Tokens missing from the vocabulary map to the <UNK> index. This method does
        NOT add <SOS>/<EOS>; callers that need them must add them themselves.

        :param sentence: string, the sentence to convert
        :param max_len: int or None, keep at most this many tokens (None keeps all)
        :return: list of int, one index per kept token
        """
        tokenized_sentence = self.tokenize(sentence)
        if max_len is not None:
            tokenized_sentence = tokenized_sentence[:max_len]

        unknown_index = self.tokens_to_index["<UNK>"]
        return [self.tokens_to_index.get(word,unknown_index) for word in tokenized_sentence]


    def index_to_sentence(self,indices):
        """
        Convert indices back to tokens, used for visualization.

        :param indices: list of word indices, e.g. [2,3,6,9,10,...]
        :return: list of tokens, e.g. ["today","good","date",...]; an index that is
                 not in the vocabulary yields None
        """
        return [self.index_to_tokens.get(index) for index in indices]



    @staticmethod
    def tokenize(content):
        """
        Split a string into lower-cased tokens with the module-level spaCy tokenizer.

        :param content: string to tokenize
        :return: list of lower-cased token strings
        """
        tokens = [token.text.lower() for token in SPACY_OBJ.tokenizer(content)]
        return tokens

ModuleNotFoundError: No module named 'spacy'

In [4]:
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import Compose,ToTensor,Normalize,Resize,RandomCrop
import os
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image


#Doing Data Processing
class Flickr8K(Dataset):
    """
    Dataset of (image, caption) pairs read from an image folder and a captions CSV.

    The vocabulary is built from every caption in the CSV when the dataset is created.
    """
    def __init__(self,root_path,captions_path,transform=None,train = True):
        """
        :param root_path: folder that contains the image files
        :param captions_path: path of the CSV file; it must provide "image" and "caption" columns
        :param transform: optional callable applied to each loaded PIL image
        :param train: accepted for API symmetry; not used by this class
        """
        self.root_path = root_path
        self.transform_img = transform

        # One row per (image file name, caption) pair
        self.pandas_dataframe = pd.read_csv(captions_path)
        self.imgs_list = self.pandas_dataframe["image"]
        self.captions_list = self.pandas_dataframe["caption"]

        # Vocabulary created from all captions of the CSV
        vocabulary = MyVocab()
        vocabulary.build_vocab(list(self.captions_list))
        self.word_dictionary = vocabulary



    def __getitem__(self,index):
        """
        Load one entry: the image is the model input, the caption the target.

        :param index: int, row of the captions table
        :return: tuple (img, caption_indices)
        img: the RGB image, passed through the transform when one was given
        caption_indices: 1-D tensor of word indices wrapped in <SOS> ... <EOS>
        """
        # Resolve the file name (e.g. 10...12.jpg) against the image folder and load it
        img_path = os.path.join(self.root_path,self.imgs_list[index])
        img = Image.open(img_path).convert("RGB")

        if self.transform_img:
            img = self.transform_img(img)

        # Numericalize the caption and surround it with the start/end markers
        vocabulary = self.word_dictionary
        caption_to_index = [vocabulary.tokens_to_index["<SOS>"]]
        caption_to_index.extend(vocabulary.sentence_to_index(self.captions_list[index]))
        caption_to_index.append(vocabulary.tokens_to_index["<EOS>"])

        return img,torch.tensor(caption_to_index)


    def __len__(self):
        """
        :return: number of rows (image/caption pairs) in the captions table
        """
        return len(self.pandas_dataframe)



class ProcessCaption:
    """
    Collate function for DataLoader that batches (image, caption) pairs.

    Images are stacked along a new leading batch axis. Captions can have different
    lengths, so they are padded with pad_idx to the longest caption of the batch.
    """
    def __init__(self,pad_idx,batch_first = False):
        """
        :param pad_idx: value used to pad the shorter captions
        :param batch_first: if True the padded captions are [batch_size,seq_len],
                            otherwise [seq_len,batch_size]
        """
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    @property
    def batch_fist(self):
        """Backward-compatible alias for the originally misspelled attribute name."""
        return self.batch_first

    @batch_fist.setter
    def batch_fist(self,value):
        self.batch_first = value

    def __call__(self,batch):
        """
        :param batch: list of (image_tensor, caption_tensor) entries
        :return: tuple: (tensor,tensor)
        img_pixel: every image stacked along a new first axis; all images must share one shape
        caption: the padded captions, laid out according to batch_first
        """

        # Stack the images along a new batch dimension
        img_pixel = torch.stack([item[0] for item in batch], dim=0)

        # Pad the variable-length captions to a common length
        caption = torch.nn.utils.rnn.pad_sequence([item[1] for item in batch],
                                                  batch_first=self.batch_first,
                                                  padding_value=self.pad_idx)

        return img_pixel, caption


# Root folder of the Flickr8K data. Set the FLICKR8K_DATA_PATH environment variable to
# point at another location; otherwise the original author's local path is used.
DATA_PATH = os.environ.get(
    "FLICKR8K_DATA_PATH",
    "C:/Users/DELL/Desktop/NYU/NYU Class Materials/NYUSpring2021Courses/Machine Learning/Final Project/数据集",
)

# Used to transform the original Images
# NOTE(review): RandomCrop(224) directly after Resize((224,224)) leaves the image
# unchanged — resize to a larger size first if random-crop augmentation is intended.
pre_transform = Compose([Resize((224,224)),
                         RandomCrop(224),
                         ToTensor(),
                         Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
                         ])

# Single dataset instance shared by the training and the "testing" loaders
dataset = Flickr8K(root_path=DATA_PATH + "/Images",
                       captions_path=DATA_PATH + "/captions.txt",
                       transform=pre_transform,
                       train = True
                       )

# Vocabulary size (special tokens included); sizes the embedding and output layers
VOCAB_DIM = len(dataset.word_dictionary)

# Index of <PAD>; used for caption padding and ignored by the loss
PAD_IDX = dataset.word_dictionary.tokens_to_index["<PAD>"]


def get_data_loader(train=True):
    """
    Build a shuffled DataLoader over the module-level dataset.

    Both modes iterate over the same dataset; only the batch size differs.
    :param train: True selects TRAINING_BATCH_SIZE, False selects TESTING_BATCH_SIZE
    :return: A dataloader object(iterable) yielding (images, padded captions) batches
    """
    batch_size = TRAINING_BATCH_SIZE if train else TESTING_BATCH_SIZE

    # Captions are padded per batch and returned batch-first
    collate = ProcessCaption(pad_idx=PAD_IDX,batch_first=True)

    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=True,
                      collate_fn=collate)


def show_image(img_pixel,label = None):
    """
    Auxiliary function that displays a single image tensor with matplotlib.

    :param img_pixel: tensor object [channels,height,width]; it may live on any device
    :param label: optional title drawn above the image
    """

    # Detach and copy to host memory first: .numpy() raises on CUDA tensors and on
    # tensors that track gradients. Then reorder for imshow => [height,width,channels]
    img_pixel = img_pixel.detach().cpu().numpy().transpose((1,2,0))
    plt.imshow(img_pixel)
    if label:
        plt.title(label)
    plt.pause(0.001)
    plt.show()


ModuleNotFoundError: No module named 'torchvision'

In [None]:
#CNN
import torch.nn as nn
import torchvision.models as models


class MyEncoderCNN(nn.Module):
    """
    Frozen, pretrained CNN that turns images into a sequence of region features.
    """
    def __init__(self,cnn_type="resnet"):
        """
        Load a pretrained torchvision backbone (resnet50, googlenet or vgg19), freeze
        its weights and drop its trailing pooling/classification layers so that the
        output of the last convolutional stage is returned.

        :param cnn_type: one of "resnet", "googlenet", "vggnet"
        """
        super().__init__()

        assert cnn_type in ["resnet","googlenet","vggnet"]
        self.cnn_type = cnn_type

        # Number of trailing child modules to drop so that the spatial feature map survives
        if cnn_type == "resnet":
            model = models.resnet50(pretrained=True)
            dropped_children = 2   # avgpool, fc
        elif cnn_type == "googlenet":
            model = models.googlenet(pretrained=True)
            # avgpool, dropout, fc. Dropping only fc would keep the global average
            # pooling and collapse the map to a single 1x1 region.
            dropped_children = 3
        else:
            model = models.vgg19(pretrained=True)
            dropped_children = 2   # avgpool, classifier

        # The backbone is a fixed feature extractor: no gradients for its weights
        for parameter in model.parameters():
            parameter.requires_grad_(False)

        # NOTE(review): wrapping the children in nn.Sequential bypasses any logic in
        # the backbone's own forward() — confirm this is acceptable for googlenet.
        modules = list(model.children())[:-dropped_children]
        self.cnn_model = nn.Sequential(*modules)




    def forward(self,images):
        """
        Encode pre-processed images into region features for the decoder.

        :param images: image_pixel [batch_size,3,height,width]
        :return: tensor [batch_size,num_regions,channels], where num_regions is the
                 flattened height*width of the CNN feature map (e.g. [batch_size,49,2048]
                 for resnet on 224x224 input)
        """
        output_features = self.cnn_model(images)            # [batch_size,channels,h,w]

        output_features = output_features.permute(0,2,3,1)  # [batch_size,h,w,channels]

        # Flatten the height and width of the feature map into one region axis
        return output_features.reshape(output_features.size(0),-1,output_features.size(-1))


In [None]:
# Attention mechanism
#Implement the attention techniques
import torch
import torch.nn as nn
import torch.nn.functional as F


# For testing
import numpy as np


class Attention(nn.Module):
    """
    Attention over the encoded image regions; the weights are also used for visualization.

    attention_method selects the score function ("dot", "general" or "concat").
    attention_type only matters for "concat": "global" scores the concatenation of
    the hidden state and each region, "local" projects both into attention_dim and
    scores their sum.
    """
    def __init__(self,encoder_dim,decoder_dim,attention_dim=None,attention_method="concat",attention_type="global"):
        """
        :param encoder_dim: feature size of each encoded image region
        :param decoder_dim: size of the decoder hidden state
        :param attention_dim: projection size, required when attention_type is "local"
        :param attention_method: "dot", "general" or "concat"
        :param attention_type: "global" or "local"
        :raises ValueError: if attention_type is "local" and attention_dim is None
        """
        super().__init__()
        assert attention_type in ["global","local"]
        assert attention_method in ["dot","general","concat"], "method error"
        if attention_type == "local" and attention_dim is None:
            raise ValueError("attention_dim must be given when attention_type is 'local'")

        self.method = attention_method
        self.type = attention_type
        self.attention_dim = attention_dim

        if self.type == "local":
            self.Wa = nn.Linear(decoder_dim,attention_dim)
            self.Va = nn.Linear(encoder_dim,attention_dim)
            self.out = nn.Linear(attention_dim,1)

        else:
            if self.method == "general":
                self.Wa = nn.Linear(encoder_dim,decoder_dim,bias=False)
            elif self.method == "concat":
                self.Wa = nn.Linear(encoder_dim+decoder_dim,decoder_dim,bias=False)
                self.Va = nn.Linear(decoder_dim,1)


    def forward(self,hidden_state,encoder_outputs):
        """
        Compute the attention weights and the context vector for one decoding step.

        :param hidden_state: decoder hidden state, [batch_size,decoder_dim] or
                             [num_layer,batch_size,decoder_dim] (the last layer is used)
        :param encoder_outputs: outputs of the CNN encoder. [batch_size,seq_len,encoder_dim]
        :return: tuple (attention_weight [batch_size,seq_len], context [batch_size,feature_dim])
        """
        if self.method == "dot":
            return self.dot_score(hidden_state,encoder_outputs)

        elif self.method == "general":
            return self.general_score(hidden_state,encoder_outputs)

        elif self.method == "concat":
            return self.concat_score(hidden_state,encoder_outputs)


    @staticmethod
    def _last_layer(hidden_state):
        """Return the hidden state as [batch_size,decoder_dim], taking the last layer of a 3-D state."""
        if hidden_state.dim() == 3:
            return hidden_state[-1]
        return hidden_state


    @staticmethod
    def _weighted_context(features,scores):
        """Softmax the scores over the regions and return (weights, weighted sum of features)."""
        attention_weight = F.softmax(scores,dim=1)                           # [batch_size,seq_len]
        context = (features * attention_weight.unsqueeze(2)).sum(dim=1)      # [batch_size,feature_dim]
        return attention_weight,context


    def dot_score(self,hidden_state,encoder_outputs):
        """
        Dot-product score; requires encoder_dim == decoder_dim.
        Deprecated, not used by the project.
        :param hidden_state: [batch_size,decoder_dim] or [num_layer,batch_size,decoder_dim]
        :param encoder_outputs: [batch_size,seq_len,encoder_dim]
        :return: tuple (attention_weight [batch_size,seq_len], context [batch_size,encoder_dim])
        """
        query = self._last_layer(hidden_state).unsqueeze(2)   # [batch_size,decoder_dim,1]
        scores = encoder_outputs.bmm(query).squeeze(-1)       # [batch_size,seq_len]
        return self._weighted_context(encoder_outputs,scores)


    def general_score(self,hidden_state,encoder_outputs):
        """
        Bilinear ("general") score: the regions are projected to decoder_dim before the dot product.
        Deprecated, not used by the project.
        :param hidden_state: [batch_size,decoder_dim] or [num_layer,batch_size,decoder_dim]
        :param encoder_outputs: [batch_size,seq_len,encoder_dim]
        :return: tuple (attention_weight [batch_size,seq_len], context [batch_size,encoder_dim])
        """
        # The projection is only used for scoring; the context is built from the
        # original features so that it keeps the encoder dimension.
        projected = self.Wa(encoder_outputs)                  # [batch_size,seq_len,decoder_dim]
        query = self._last_layer(hidden_state).unsqueeze(2)   # [batch_size,decoder_dim,1]
        scores = projected.bmm(query).squeeze(-1)             # [batch_size,seq_len]
        return self._weighted_context(encoder_outputs,scores)


    def concat_score(self,hidden_state,encoder_outputs):
        """
        Additive score over the hidden state and each region.
        :param hidden_state: [batch_size,decoder_dim] or [num_layer,batch_size,decoder_dim]
        :param encoder_outputs: [batch_size,seq_len,encoder_dim]
        :return: tuple (attention_weight [batch_size,seq_len], context); the context is
                 [batch_size,attention_dim] for "local" and [batch_size,encoder_dim] for "global"
        """
        hidden_state = self._last_layer(hidden_state)

        # If we use the local attention
        if self.type=="local":
            encoder_out = self.Va(encoder_outputs)  # [batch_size,seq_len,attention_dim]
            decoder_out = self.Wa(hidden_state)     # [batch_size,attention_dim]

            combined_states = torch.tanh(encoder_out+decoder_out.unsqueeze(1))
            attention_scores = self.out(combined_states).squeeze(2)   # [batch_size,seq_len]

            # The context is built from the projected regions
            return self._weighted_context(encoder_out,attention_scores)

        # If we use the global attention: pair the hidden state with every region
        repeated_hidden = hidden_state.unsqueeze(1).expand(-1,encoder_outputs.size(1),-1)  # [batch_size,seq_len,decoder_dim]
        concated = torch.cat([repeated_hidden,encoder_outputs],dim=-1)  # [batch_size,seq_len,decoder_dim+encoder_dim]

        attention_scores = self.Va(torch.tanh(self.Wa(concated))).squeeze(-1)  # [batch_size,seq_len]
        return self._weighted_context(encoder_outputs,attention_scores)


In [None]:
#DecoderRNN
import torch
import torch.nn as nn

class MyDecoderRNN(nn.Module):
    """
    Attention-based RNN decoder that generates a caption from encoded image regions.

    At every time step the decoder attends over the image regions, concatenates the
    resulting context with the embedding of the previous word and feeds both to an
    LSTM cell (or a GRU cell when GRU=True).
    """
    def __init__(self,vocab_dim,
                 embedding_dim,
                 encoder_dim,
                 decoder_dim,
                 n_layers = 1,
                 dropout_p=0.5,
                 attention_dim=None,
                 attention_type = "global",
                 GRU = False):
        """
        :param vocab_dim: vocabulary size (embedding rows and output logits)
        :param embedding_dim: size of the word embeddings
        :param encoder_dim: feature size of each encoded image region
        :param decoder_dim: hidden size of the recurrent cell
        :param n_layers: stored only; the decoder always steps a single cell
        :param dropout_p: dropout probability applied before the output layer
        :param attention_dim: projection size, required for "local" attention
        :param attention_type: "global" or "local"
        :param GRU: use a GRU cell instead of the LSTM cell
        """
        super().__init__()

        # Hyper-parameters
        self.attention_type = attention_type
        self.vocab_dim = vocab_dim
        self.n_layers = n_layers
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.attention_dim = attention_dim
        self.GRU = GRU


        # Embedding layer for the caption words; sized by the vocab_dim argument so
        # that it always matches the output layer below
        self.embedding_layer = nn.Embedding(num_embeddings=vocab_dim,
                                            embedding_dim=embedding_dim,
                                            padding_idx=PAD_IDX)

        # Attention layer that distributes the focus over the image regions
        if self.attention_type == "global":
            self.attn = Attention(encoder_dim,decoder_dim)

        elif self.attention_type == "local":
            self.attn = Attention(encoder_dim,decoder_dim,attention_dim=self.attention_dim,attention_type="local")

        # Regularization
        self.dropout = nn.Dropout(dropout_p)

        # Layers that derive the initial hidden/cell state from the mean image feature
        self.init_hidden = nn.Linear(encoder_dim,decoder_dim)
        self.init_cell = nn.Linear(encoder_dim,decoder_dim)

        # "local" attention returns a context of size attention_dim, "global" one of size encoder_dim
        context_dim = attention_dim if self.attention_type == "local" else encoder_dim

        # Single cells that are stepped manually. The LSTM cell is created even in
        # GRU mode so that the parameter names of existing checkpoints stay valid.
        self.lstm_cell = nn.LSTMCell(embedding_dim+context_dim,decoder_dim,bias=True)

        if self.GRU:
            self.gru_cell = nn.GRUCell(embedding_dim+context_dim,decoder_dim,bias=True)


        # Projects the hidden state to vocabulary logits
        self.fcn = nn.Linear(decoder_dim,vocab_dim)


    def _init_state(self,image_features):
        """Initialize (hidden_state, cell) from the mean over all image regions."""
        image_features_init = image_features.mean(dim=1)
        return self.init_hidden(image_features_init),self.init_cell(image_features_init)


    def _step(self,embedded_word,image_features,hidden_state,cell):
        """
        Run one decoding step.

        :return: tuple (logits [batch_size,vocab_dim], attention_weight [batch_size,num_regions],
                 new hidden_state, new cell); cell is passed through unchanged in GRU mode
        """
        attention_weight,context = self.attn(hidden_state,image_features)
        rnn_input = torch.cat([embedded_word,context],dim=1)

        if self.GRU:
            hidden_state = self.gru_cell(rnn_input,hidden_state)
        else:
            hidden_state,cell = self.lstm_cell(rnn_input,(hidden_state,cell))

        output = self.fcn(self.dropout(hidden_state))
        return output,attention_weight,hidden_state,cell


    def forward(self,image_features,captions):
        """
        Teacher-forced decoding: the ground-truth word t is the input that predicts word t+1.

        :param image_features: encoder_outputs [batch_size,num_regions,encoder_dim]
        :param captions: numericalized captions [batch_size,max_len]
        :return: tuple (outputs [batch_size,max_len-1,vocab_dim],
                 attention_weights [batch_size,max_len-1,num_regions])
        """

        embedded_captions = self.embedding_layer(captions) #[batch_size,max_len,embed_dim]

        hidden_state,cell = self._init_state(image_features)

        # The last word is never used as an input, hence max_len-1 steps
        seq_len = captions.size(1)-1
        batch_size = captions.size(0)
        num_regions = image_features.size(1)

        # Result buffers, allocated on the same device as the inputs
        device = image_features.device
        outputs = torch.zeros(batch_size,seq_len,self.vocab_dim,device=device)
        attention_weights = torch.zeros(batch_size,seq_len,num_regions,device=device)

        for t in range(seq_len):
            output,attention_weight,hidden_state,cell = self._step(
                embedded_captions[:,t],image_features,hidden_state,cell)

            outputs[:,t] = output
            attention_weights[:,t] = attention_weight

        return outputs,attention_weights


    def generate_caption(self,image_features,max_len=15,vocabulary=None):
        """
        Greedily generate a caption for a single image.

        :param image_features: encoder output for ONE image, [1,num_regions,encoder_dim]
        :param max_len: maximum number of generated words
        :param vocabulary: vocabulary used to map indices to words; defaults to the
                           vocabulary of the module-level dataset
        :return: tuple (list of generated tokens, list of per-step attention weights as numpy arrays)
        """
        if vocabulary is None:
            vocabulary = dataset.word_dictionary

        hidden_state,cell = self._init_state(image_features)

        # Starting to feed words into the RNN decoder by <SOS>
        word = torch.tensor(vocabulary.tokens_to_index["<SOS>"],device=image_features.device).view(1,-1)
        embedded = self.embedding_layer(word)  # [1,1,embed_dim]

        attention_weights_list = []
        caption_indices = []

        # Stop once the maximum sentence length is reached
        for _ in range(max_len):
            output,attention_weights,hidden_state,cell = self._step(
                embedded[:,0],image_features,hidden_state,cell)

            # store the attention weights into the list
            attention_weights_list.append(attention_weights.cpu().detach().numpy())

            # Greedy choice; .item() is why only a batch of one image is supported
            predicted_word_index = output.argmax(dim=1)  # [batch_size]
            predicted_index = predicted_word_index.item()
            caption_indices.append(predicted_index)

            # Stop as soon as <EOS> is predicted
            if vocabulary.index_to_tokens[predicted_index] == "<EOS>":
                break

            # for the next iteration
            embedded = self.embedding_layer(predicted_word_index.unsqueeze(0))

        caption_outputs = [vocabulary.index_to_tokens[index] for index in caption_indices]

        return caption_outputs,attention_weights_list


In [None]:
#Model
import torch.nn as nn

class MyModel(nn.Module):
    """
    This class encapsulates the final encoder-decoder model.
    """

    # Channel width of the feature map of each supported backbone. The decoder's
    # input layers must match it, otherwise the first forward pass fails.
    _ENCODER_DIMS = {"resnet": 2048, "googlenet": 1024, "vggnet": 512}

    def __init__(self,cnn_type="resnet",attention_dim=None,attention_type="global",GRU=False):
        """
        :param cnn_type: backbone of the encoder, "resnet", "googlenet" or "vggnet"
        :param attention_dim: projection size, required for "local" attention
        :param attention_type: "global" or "local"
        :param GRU: use a GRU cell instead of an LSTM cell in the decoder
        """
        super().__init__()
        self.encoder = MyEncoderCNN(cnn_type=cnn_type)

        # Derive the feature size from the backbone instead of trusting the global
        # constant, which is only correct for one of the backbones at a time.
        encoder_dim = self._ENCODER_DIMS.get(cnn_type,ENCODER_DIM)

        self.decoder = MyDecoderRNN(
            vocab_dim=VOCAB_DIM,
            embedding_dim=EMBED_DIM,
            encoder_dim=encoder_dim,
            decoder_dim= DECODER_DIM,
            attention_dim=attention_dim,
            attention_type=attention_type,
            GRU=GRU
        )


    def forward(self,imgs,captions):
        """
        :param imgs: batch of pre-processed images
        :param captions: numericalized captions [batch_size,max_len]
        :return: tuple (outputs, attention_weights) exactly as returned by the decoder
        """
        image_features = self.encoder(imgs)
        outputs,attention_weights = self.decoder(image_features,captions)
        return outputs,attention_weights


In [None]:
# One independent loss history per experiment (backbone + attention + cell type)
loss_list_ResNetGlobalLSTM, loss_list_ResNetGlobalGRU = [], []
loss_list_VGG19GlobalLSTM, loss_list_VGG19GlobalGRU = [], []

In [None]:
from torch.optim import Adam,RMSprop,SGD,ASGD,Adagrad
import torch.nn as nn
import torch
import matplotlib.pyplot as plt



def save_model(model,epochs,model_name):
    """
    Save the model's parameters together with the main hyper-parameters.

    The checkpoint is written to ./模型存放/<model_name>.pkl; the folder is created
    when it does not exist yet.
    :param model: the model whose state_dict is saved
    :param epochs: epoch counter stored with the checkpoint
    :param model_name: file name (without extension) of the checkpoint
    :return: None
    """
    import os  # local import: this cell does not import os itself

    # torch.save does not create missing folders
    os.makedirs("./模型存放",exist_ok=True)

    # Prefer the feature size the decoder was actually built with; fall back to the
    # global constant for models that do not expose it.
    encoder_dim = getattr(getattr(model,"decoder",None),"encoder_dim",ENCODER_DIM)

    model_param = {
        "epochs":epochs,
        "vocab_size":VOCAB_DIM,
        "embed_size":EMBED_DIM,
        "encoder_dim":encoder_dim,
        "decoder_dim":DECODER_DIM,
        "attention_dim":ATTENTION_DIM,
        "model_state_dict":model.state_dict()
    }

    torch.save(model_param,"./模型存放/{}.pkl".format(model_name))


def evaluation(epoch,model,model_name,loss_list,display_steps=20):
    """
    Train the model for one epoch and periodically preview a generated caption.

    Relies on the module-level `optimizer` and `criterion_metrics`. Every
    display_steps batches the current loss is printed and appended to loss_list,
    and a caption is generated for one image. The model is saved once, after the
    whole epoch.
    :param epoch: zero-based epoch index (printed as epoch + 1 and stored in the checkpoint)
    :param model: the encoder-decoder model to train
    :param model_name: checkpoint file name handed to save_model
    :param loss_list: list that collects the displayed loss values
    :param display_steps: number of batches between two progress reports
    :return: None
    """

    for idx, (image, targets) in enumerate(get_data_loader(train=True)):
        # image: batch of pre-processed images
        # targets [batch_size,max_len]

        image, targets = image.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()

        # outputs [batch_size,seq_len,vocab_size]
        outputs, attention_weights = model(image, targets)

        # The words after <SOS> are the prediction targets, target [batch_size,seq_len]
        target = targets[:, 1:]

        # Cross-entropy loss over the batch, used for backpropagation
        l = criterion_metrics(outputs.reshape(-1, outputs.size(-1)), target.reshape(-1))

        l.backward()

        optimizer.step()

        if (idx + 1) % (display_steps) == 0:
            print("Epoch: {} loss: {:.5f}".format(epoch + 1, l.item()))
            loss_list.append(l.item())
            # Switch to evaluation mode (disables dropout)
            model.eval()
            with torch.no_grad():
                # NOTE(review): this loads a whole testing batch just to preview its first image
                data_loader = iter(get_data_loader(train=False))
                image, caption = next(data_loader)
                image_features = model.encoder(image[0:1].to(DEVICE))

                captions_result, attention_weights = model.decoder \
                    .generate_caption(image_features)

                sentence = " ".join(captions_result)

                show_image(image[0], label=sentence)
            model.train()

    # Save once per epoch; saving inside the loop rewrote the checkpoint after every batch
    save_model(model,epoch,model_name)

In [None]:
#ResNetGlobalLSTM
# Experiment "ResNetGlobalLSTM": default MyModel settings, moved to the selected device
seq2seq = MyModel()
seq2seq = seq2seq.to(DEVICE)

# Adam over all model parameters with the configured learning rate
optimizer = Adam(params=seq2seq.parameters(), lr=LEARNING_RATE)

# Cross-entropy that skips the padded caption positions
criterion_metrics = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [None]:
%%time
# Train for two epochs. The model must be passed as the second argument; without it
# every following argument is shifted by one position.
for i in range(2):
    evaluation(i,seq2seq,"ResNetGlobalLSTM",loss_list_ResNetGlobalLSTM,50)