In [2]:
import os
import argparse
import multiprocessing as mp
from sys import argv

import torch
import numpy as np
import pytorch_lightning as pl
from torch import nn
from tqdm import tqdm
from pytorch_lightning.loggers import NeptuneLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5EncoderModel
from transformers import AdamW
from transformers.modeling_outputs import BaseModelOutput

from Datasets import WikiTable
from metrics import compute_exact, compute_f1

In [None]:
class ConvBlock(nn.Module):
    """Two stacked Conv2d(3x3) -> BatchNorm2d -> LeakyReLU stages.

    The first convolution keeps the spatial size (stride 1, padding 1); the
    second uses stride 2, so H and W are reduced to ceil(H/2), ceil(W/2).

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # (input channels, stride) for each conv stage, in order.
        for stage_in, stage_stride in ((in_channels, 1), (out_channels, 2)):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3,
                          stride=stage_stride, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(),
            ])
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to `x` (N, in_channels, H, W)."""
        out = self.block(x)
        return out

In [None]:
class CNNTransformer(pl.LightningModule):
    """Table question answering: CNN image features + T5 question encoding -> T5 decoder.

    A stack of ConvBlocks embeds the table image into ``d_model`` channels, a T5
    encoder embeds the question, the two sequences are concatenated and passed
    through a linear adapter, and the result is handed to a T5 model as its
    encoder output, either for training (loss) or for answer generation.

    Args:
        params: attribute-style config object (presumably an argparse.Namespace;
            confirm). Attributes read here: ``t5_model``, ``hidden_dim``, ``gpus``,
            ``lr``, ``batch_size``, ``nworkers``.
    """

    def __init__(self, params):
        super().__init__()

        # Experiment configuration, read via attribute access.
        self.params = params

        # Tokenizer used to decode generated token ids back into text.
        self.tokenizer = T5Tokenizer.from_pretrained(self.params.t5_model)

        # Question encoder: the T5 encoder stack only.
        self.sentence_encoder = T5EncoderModel.from_pretrained(self.params.t5_model)

        # Full T5 seq2seq model. forward() always supplies precomputed
        # `encoder_outputs`, so only its decoder side does work.
        self.decoder = T5ForConditionalGeneration.from_pretrained(self.params.t5_model)

        # Linear map over the concatenated (image + question) sequence axis; see forward().
        self.adapter = nn.Linear(self.params.hidden_dim, self.decoder.config.d_model)

        # Image encoder: four ConvBlocks ending in d_model channels so image features
        # can be concatenated with the question hidden states.
        self.CNNEmbedder = nn.Sequential(ConvBlock(3, 16),
                                         ConvBlock(16, 64),
                                         ConvBlock(64, 256),
                                         ConvBlock(256, self.decoder.config.d_model))

    def forward(self, batch):
        """Run the model on one batch.

        Args:
            batch: dict with ``table_img`` (3-channel image tensor for the CNN),
                ``question_ids`` / ``question_attn_mask`` (tokenized question) and,
                in training mode, ``target_ids`` (answer token ids used as labels).

        Returns:
            In training mode, the loss returned by the T5 model for ``target_ids``.
            Otherwise, the generated answer token ids (see ``generate_prediction``).
        """
        table_imgs = batch['table_img']
        questions_ids = batch['question_ids']
        questions_attn_mask = batch['question_attn_mask']

        # (B, seqlen, d_model)
        encoder_hidden_state = self.sentence_encoder(
            input_ids=questions_ids,
            attention_mask=questions_attn_mask).last_hidden_state

        # (B, 3, H_in, W_in) -> (B, d_model, H, W)
        img_features = self.CNNEmbedder(table_imgs)
        B = table_imgs.size(0)

        # (B, d_model, H, W) -> (B, H, W, d_model) -> (B, H*W, d_model).
        # reshape, not view: the permuted tensor is not contiguous.
        img_features = img_features.permute(0, 2, 3, 1).reshape(
            B, -1, self.decoder.config.d_model)

        # (B, H*W + seqlen, d_model)
        combined_feat = torch.cat([img_features, encoder_hidden_state], dim=1)

        # (B, d_model, H*W + seqlen): move the sequence axis last so the adapter acts on it.
        combined_feat = combined_feat.permute(0, 2, 1)

        # NOTE(review): nn.Linear(hidden_dim, d_model) requires H*W + seqlen == hidden_dim
        # (fixed image size and fixed padded question length). Confirm this is intended.
        # Result: (B, d_model, d_model).
        proj_features = self.adapter(combined_feat)

        # Swap back so the adapter's output axis becomes the sequence axis of the
        # encoder states given to the decoder: (B, d_model [seq], d_model).
        proj_features = proj_features.permute(0, 2, 1)

        if self.training:
            return self.decoder(
                encoder_outputs=BaseModelOutput(last_hidden_state=proj_features),
                labels=batch['target_ids']).loss
        return self.generate_prediction(proj_features)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self(batch)
        # Log the loss tensor as-is; synchronize across processes only for multi-GPU runs.
        self.log('loss', loss, on_step=True, on_epoch=True,
                 sync_dist=self.params.gpus > 1)
        return loss

    def generate_prediction(self, proj_features, max_length=None):
        """Generate answer token ids from precomputed encoder states.

        Args:
            proj_features: (B, seq, d_model) tensor used as the T5 encoder output.
            max_length: optional cap on generated length; when None the decoder's
                default generation settings apply.

        Returns:
            Tensor of generated token ids, one row per batch element.
        """
        gen_kwargs = {} if max_length is None else {"max_length": max_length}
        # generate() expects a ModelOutput here (it indexes/expands it by key).
        return self.decoder.generate(
            encoder_outputs=BaseModelOutput(last_hidden_state=proj_features),
            **gen_kwargs)

    def evaluation_step(self, batch):
        """Shared validation/test step: generate answers and decode them to text.

        Returns:
            ``(batch["answer"], preds)``: the reference answers and the decoded
            predicted strings, aggregated later by ``epoch_end``.
        """
        pred_tokens = self(batch)
        preds = self.tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)
        return batch["answer"], preds

    def validation_step(self, batch, batch_idx):
        return self.evaluation_step(batch)

    def test_step(self, batch, batch_idx):
        return self.evaluation_step(batch)

    def epoch_end(self, outputs, phase):
        """Aggregate per-batch (targets, predictions) and log mean F1 / exact match.

        Args:
            outputs: list of ``(answers, preds)`` pairs returned by ``evaluation_step``.
            phase: prefix for the logged metric names, e.g. ``"val"`` or ``"test"``.
        """
        tgts, preds = [], []
        for batch_tgts, batch_preds in outputs:
            tgts.extend(batch_tgts)
            preds.extend(batch_preds)
        if not tgts:
            # Nothing evaluated (e.g. empty split): don't log NaN means.
            return

        f1s = [compute_f1(tgt, pred) for tgt, pred in zip(tgts, preds)]
        exacts = [compute_exact(tgt, pred) for tgt, pred in zip(tgts, preds)]
        self.log_dict({f"{phase}_f1": float(np.mean(f1s)),
                       f"{phase}_exact_match": float(np.mean(exacts))},
                      prog_bar=True, on_step=False, on_epoch=True)

    # Lightning (1.x) epoch-end hooks; without them epoch_end would never run.
    def validation_epoch_end(self, outputs):
        self.epoch_end(outputs, "val")

    def test_epoch_end(self, outputs):
        self.epoch_end(outputs, "test")

    def configure_optimizers(self):
        # Hyperparameters live in self.params; self.hparams is never populated.
        return torch.optim.Adam(self.parameters(), lr=self.params.lr)

    def _make_dataloader(self, mode, shuffle):
        """Wrap the WikiTable split `mode` in a DataLoader configured from params."""
        # NOTE(review): assumes WikiTable(mode, tokenizer) is the dataset constructor
        # and that `Modes` is provided by an earlier cell or `Datasets`. It is not
        # imported in the visible notebook; confirm.
        return DataLoader(WikiTable(mode, self.tokenizer),
                          batch_size=self.params.batch_size,
                          shuffle=shuffle,
                          num_workers=self.params.nworkers)

    def train_dataloader(self):
        return self._make_dataloader(Modes.TRAIN, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(Modes.VAL, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(Modes.TEST, shuffle=False)