In [None]:
# !pip install -q pytorch-lightning
# !pip install transformers
# !pip install librosa
# !pip install wandb

In [1]:
import librosa
import math
import os
import pandas as pd
import pathlib
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
import numpy as np
from collections import OrderedDict 

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Dataset
from torch.nn import functional as F
from torchvision import datasets, models, transforms
from torchvision.models.utils import load_state_dict_from_url
import torchmetrics
from torchmetrics import Accuracy, MetricCollection, Precision, Recall, F1
from transformers import Wav2Vec2Model, Wav2Vec2Processor

from typing import Type, Any, Callable, Union, List, Dict, Optional, cast

import wandb

# Fix the global random seed through Lightning's seed_everything helper so runs are repeatable.
pl.utilities.seed.seed_everything(42)

Global seed set to 42


42

In [2]:
class AudioDataset(Dataset):
    """Fixed-length audio samples described by a CSV annotation file.

    The CSV must contain a ``Vid_name`` column (file stem of the clip, resolved
    against ``root_dir`` with a ``.wav`` suffix) and a ``Label`` column
    (1-based class id).

    Args:
        annotations_path: Path to the comma-separated annotation file.
        root_dir: Directory that holds the ``.wav`` files.
        tokenizer: Callable invoked as
            ``tokenizer(audio, sampling_rate=..., return_tensors="pt",
            max_length=..., padding="max_length")`` whose result exposes
            ``input_values``.
        sample_rate: Rate the audio is resampled to on load and reported to
            the tokenizer. Defaults to the previously hard-coded 16000.
        max_length: Number of samples every item is padded/truncated to.
            Defaults to the previously hard-coded 129600.
    """

    def __init__(self, annotations_path, root_dir, tokenizer,
                 sample_rate=16000, max_length=129600):
        self.annotations = pd.read_csv(annotations_path, delimiter=',')
        self.root_dir = pathlib.Path(root_dir)
        self.tokenizer = tokenizer
        self.sample_rate = sample_rate
        self.max_length = max_length

    def __len__(self):
        """Return the number of annotated clips."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``{'sample': tensor, 'label': zero-based label}``."""
        data_row = self.annotations.iloc[idx]
        # NOTE(review): with_suffix replaces anything after the last dot in
        # Vid_name; names containing dots would resolve to the wrong file.
        wav_path = (self.root_dir / data_row['Vid_name']).with_suffix('.wav')
        audio, _ = librosa.load(wav_path, sr=self.sample_rate)
        tokenized_audio = self.tokenizer(
            audio,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            max_length=self.max_length,
            padding="max_length",
        ).input_values
        # Padding only guarantees a minimum length, so longer clips are cut
        # here; slicing is a no-op when the length already matches.
        tokenized_audio = tokenized_audio[:, :self.max_length]
        # Labels in the CSV are 1-based; shift them to start at 0.
        return {'sample': tokenized_audio, 'label': data_row['Label'] - 1}

In [3]:
class VideoDataset(Dataset):
    """Stacks of transformed frames described by a CSV annotation file.

    The CSV must contain a ``Vid_name`` column (name of a directory under
    ``root_dir`` holding the frames of one video) and a ``Label`` column
    (1-based class id).

    Args:
        annotations_path: Path to the comma-separated annotation file.
        root_dir: Directory that contains one sub-directory per video.
        transformation: Callable applied to every opened frame; it must return
            tensors of identical shape so they can be stacked.
    """

    def __init__(self, annotations_path, root_dir, transformation):
        self.annotations = pd.read_csv(annotations_path, delimiter=',')
        self.root_dir = pathlib.Path(root_dir)
        self.transform = transformation

    def __len__(self):
        """Return the number of annotated videos."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``{'sample': stacked frame tensor, 'label': zero-based label}`` for video ``idx``."""
        data_row = self.annotations.iloc[idx]
        frames_dir = self.root_dir / data_row['Vid_name']
        frames = []
        # NOTE(review): the sort is lexicographic, so frame order is only
        # chronological if file names are zero-padded - confirm.
        for frame_path in sorted(frames_dir.iterdir()):
            # Close each file as soon as it has been transformed; the original
            # left every handle open until garbage collection.
            with Image.open(frame_path) as frame:
                frames.append(self.transform(frame))
        # Labels in the CSV are 1-based; shift them to start at 0.
        return {'sample': torch.stack(frames), 'label': data_row['Label'] - 1}

In [4]:
class MetaDataset(Dataset):
    """Pairs an audio dataset with a video dataset item by item.

    The length is taken from the annotation CSV; both wrapped datasets are
    indexed with the same ``idx``.
    """

    def __init__(self,audio_dataset, video_dataset, annotations_path):
        self.audio_dataset = audio_dataset
        self.video_dataset = video_dataset
        self.annotations = pd.read_csv(annotations_path, delimiter=',')

    def __len__(self):
        """Number of rows in the annotation file."""
        return self.annotations.shape[0]

    def __getitem__(self, idx):
        """Return ``{'audio': audio item, 'video': video item}`` for position ``idx``."""
        return dict(audio=self.audio_dataset[idx], video=self.video_dataset[idx])

In [5]:
class VgafDataModule(pl.LightningDataModule):
    """Lightning data module that serves paired audio/video batches.

    Each split is a ``MetaDataset`` built from an ``AudioDataset`` and a
    ``VideoDataset`` that read the same annotation file.

    Args:
        audio_data_dir: Directory with the ``.wav`` files.
        vid_data_dir: Directory with one sub-directory of frames per video.
        batch_size: Batch size used by all three loaders.
    """

    def __init__(self, audio_data_dir='caer_vgaf_audio', vid_data_dir=pathlib.Path('caer_vgaf_frames'),
                 batch_size=16):
        super().__init__()
        self.audio_data_dir = pathlib.Path(audio_data_dir)
        self.vid_data_dir = pathlib.Path(vid_data_dir)
        self.batch_size = batch_size

        self.train_labels = pathlib.Path('Splitted_annotations_vgaf_caer/train_labels')
        self.validation_labels = pathlib.Path('Splitted_annotations_vgaf_caer/val_labels')
        self.test_labels = pathlib.Path('Splitted_annotations_vgaf_caer/test_labels')

        # Loaded once and shared by every split (the loaders used to reload it).
        self.tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

        # NOTE(review): VideoDataset applies the transform to each frame
        # separately, so the random flips are drawn independently per frame.
        self.train_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.RandomHorizontalFlip(p=0.4),
                transforms.RandomVerticalFlip(p=0.2),
                transforms.Resize((224, 224)),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
            ])

        # Deterministic pipeline used for validation and test.
        self.val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            transforms.Resize((224, 224))
        ])

    def _make_loader(self, labels_path, transformation, shuffle):
        """Build the paired audio/video DataLoader for one annotation file."""
        dataset = MetaDataset(
            annotations_path=labels_path,
            audio_dataset=AudioDataset(annotations_path=labels_path,
                                       root_dir=self.audio_data_dir,
                                       tokenizer=self.tokenizer),
            video_dataset=VideoDataset(annotations_path=labels_path,
                                       root_dir=self.vid_data_dir,
                                       transformation=transformation),
        )
        # drop_last is kept for every split as in the original code; note that
        # it silently discards the trailing samples of validation and test.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, drop_last=True)

    def train_dataloader(self):
        """Shuffled loader over the training split with augmentation."""
        return self._make_loader(self.train_labels, self.train_transform, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return self._make_loader(self.validation_labels, self.val_transform, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test split.

        Uses ``val_transform``; the original referenced a non-existent
        ``self.transform`` and raised AttributeError.
        """
        return self._make_loader(self.test_labels, self.val_transform, shuffle=False)

# Trying to use pretrained resnet

In [6]:
from torchvision.models.resnet import *
from torchvision.models.resnet import BasicBlock, Bottleneck
from torchvision.models.resnet import model_urls

In [7]:
class IntResNet(ResNet):
    """ResNet whose forward pass stops after a named child module.

    All child modules of the full ResNet are still created (so a full
    state dict can be loaded), but ``forward`` only runs the children up to
    and including ``output_layer``.

    Args:
        output_layer: Name of the last child module to run (e.g. ``'layer3'``).
        *args: Positional arguments forwarded to ``ResNet.__init__``.
        **kwargs: Keyword arguments forwarded to ``ResNet.__init__``. The
            original signature dropped these, which made construction with
            keyword options such as ``groups`` / ``width_per_group`` fail
            with TypeError.
    """

    def __init__(self, output_layer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.output_layer = output_layer

        # Names of the children to execute, in registration order, ending at
        # output_layer. If the name is not found, every child is kept.
        self._layers = []
        for name in self._modules.keys():
            self._layers.append(name)
            if name == output_layer:
                break
        self.layers = OrderedDict((name, getattr(self, name)) for name in self._layers)

    def _forward_impl(self, x):
        """Run ``x`` through the retained children sequentially."""
        for name in self._layers:
            if name == 'fc':
                # The linear head expects flattened features, as in the stock
                # ResNet forward; without this the 'fc' cut point cannot run.
                x = torch.flatten(x, 1)
            x = self.layers[name](x)
        return x

    def forward(self, x):
        """Return the activations of ``output_layer`` for input ``x``."""
        return self._forward_impl(x)

In [8]:
class NewModel(nn.Module):
    """Pretrained, truncated ResNet backbone with partially frozen children.

    Args:
        base_model: Key into ``model_dict`` (e.g. ``'resnet50'``).
        base_out_layer: Name of the child module after which the forward
            pass stops (handed to ``IntResNet``).
        num_trainable_layers: How many children, counted from the end of the
            full child list, keep ``requires_grad=True``; ``-1`` means all.
    """

    def __init__(self,base_model,base_out_layer,num_trainable_layers):
        super().__init__()
        self.base_model = base_model
        self.base_out_layer = base_out_layer
        self.num_trainable_layers = num_trainable_layers

        # Block type, per-stage depths and extra constructor options per architecture.
        self.model_dict = {
            'resnet18': dict(block=BasicBlock, layers=[2, 2, 2, 2], kwargs={}),
            'resnet34': dict(block=BasicBlock, layers=[3, 4, 6, 3], kwargs={}),
            'resnet50': dict(block=Bottleneck, layers=[3, 4, 6, 3], kwargs={}),
            'resnet101': dict(block=Bottleneck, layers=[3, 4, 23, 3], kwargs={}),
            'resnet152': dict(block=Bottleneck, layers=[3, 8, 36, 3], kwargs={}),
            'resnext50_32x4d': dict(block=Bottleneck, layers=[3, 4, 6, 3],
                                    kwargs={'groups': 32, 'width_per_group': 4}),
            'resnext101_32x8d': dict(block=Bottleneck, layers=[3, 4, 23, 3],
                                     kwargs={'groups': 32, 'width_per_group': 8}),
            'wide_resnet50_2': dict(block=Bottleneck, layers=[3, 4, 6, 3],
                                    kwargs={'width_per_group': 64 * 2}),
            'wide_resnet101_2': dict(block=Bottleneck, layers=[3, 4, 23, 3],
                                     kwargs={'width_per_group': 64 * 2}),
        }

        # Pretrained backbone (weights are always downloaded here).
        spec = self.model_dict[self.base_model]
        self.resnet = self.new_resnet(self.base_model, self.base_out_layer,
                                      spec['block'], spec['layers'],
                                      True, True,
                                      **spec['kwargs'])

        self.layers = list(self.resnet._modules.keys())

        # Freeze everything except the last `num_trainable_layers` children.
        # NOTE(review): the count runs over ALL children of the ResNet, including
        # those after base_out_layer that forward() never executes - confirm intended.
        backbone_children = list(self.resnet.children())
        self.total_children = len(backbone_children)
        self.children_counter = 0

        if self.num_trainable_layers == -1:
            self.num_trainable_layers = self.total_children

        first_trainable = self.total_children - self.num_trainable_layers
        for child in backbone_children:
            keep_trainable = not (self.children_counter < first_trainable)
            for weight in child.parameters():
                weight.requires_grad = keep_trainable
            self.children_counter += 1

    def new_resnet(self,
                   arch: str,
                   outlayer: str,
                   block: Type[Union[BasicBlock, Bottleneck]],
                   layers: List[int],
                   pretrained: bool,
                   progress: bool,
                   **kwargs: Any
                  ) -> IntResNet:
        """Build an ``IntResNet`` and optionally load its pretrained weights.

        Args:
            arch: Architecture name used to look up the weight URL in ``model_urls``.
            outlayer: Name of the child module at which the forward pass stops.
            block: Residual block class.
            layers: Number of blocks in each of the four stages.
            pretrained: Load the weights published for ``arch`` when True.
            progress: Show a download progress bar.
            **kwargs: Extra keyword arguments for the ResNet constructor.

        Returns:
            The constructed (and possibly weight-initialised) network.
        """
        network = IntResNet(outlayer, block, layers, **kwargs)
        if not pretrained:
            return network
        weights = load_state_dict_from_url(model_urls[arch],
                                           progress=progress)
        network.load_state_dict(weights)
        return network

    def forward(self,x):
        """Return the truncated backbone's activations for ``x``."""
        return self.resnet(x)

# Model

In [9]:
class BertAttention(nn.Module):
    """Multi-head scaled dot-product attention (queries from one input, keys/values from another).

    Args:
        num_heads: Number of attention heads. Previously this argument was
            ignored and 4 was always used; the default keeps that value.
        ctx_dim: Feature size of the context tensor; ``None`` means the same
            as ``hidden_size``.
        hidden_size: Feature size of the query input and of the output.
            Defaults to the previously hard-coded 256.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_heads=4, ctx_dim=256, hidden_size=256):
        super().__init__()
        self.num_attention_heads = num_heads
        self.hidden_size = hidden_size
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (self.hidden_size, self.num_attention_heads))
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        if ctx_dim is None:
            ctx_dim = self.hidden_size
        self.query = nn.Linear(self.hidden_size, self.all_head_size)
        self.key = nn.Linear(ctx_dim, self.all_head_size)
        self.value = nn.Linear(ctx_dim, self.all_head_size)

        self.dropout = nn.Dropout(p=0.2)

    def transpose_for_scores(self, x):
        """Reshape ``(batch, n, all_head_size)`` to ``(batch, heads, n, head_size)``."""
        bsz, num_feat, _ = x.shape
        x = x.view(bsz, num_feat, self.num_attention_heads, self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, context, attention_mask=None):
        """Attend from ``hidden_states`` to ``context``.

        Args:
            hidden_states: Tensor ``(batch, n_query, hidden_size)``.
            context: Tensor ``(batch, n_ctx, ctx_dim)``.
            attention_mask: Optional tensor added to the raw attention scores
                before the softmax.

        Returns:
            Tensor ``(batch, n_query, all_head_size)``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(context))
        value_layer = self.transpose_for_scores(self.value(context))

        # Scaled dot-product scores: (batch, heads, n_query, n_ctx).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = torch.softmax(attention_scores, dim=-1)
        # Dropout on the attention weights themselves.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Back to (batch, n_query, heads, head_size), then merge the heads.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*merged_shape)

In [10]:
class BertAttOutput(nn.Module):
    """Post-attention block: linear projection, dropout, residual add and layer norm.

    Args:
        hidden_size: Feature size of both inputs and of the output.
            Defaults to the previously hard-coded 256.
    """

    def __init__(self, hidden_size=256):
        super().__init__()
        self.hidden_size = hidden_size
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        # Uses nn.LayerNorm directly so the class no longer relies on a
        # module-level alias that is only defined in a later notebook cell.
        # The attribute name is unchanged, so existing state dicts still load.
        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, apply dropout, add ``input_tensor`` and normalise.

        Both arguments must be broadcast-compatible tensors whose last
        dimension equals ``hidden_size``.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Residual connection followed by layer normalisation.
        return self.LayerNorm(hidden_states + input_tensor)

In [11]:
class BertSelfattLayer(nn.Module):
    """Self-attention layer: the input attends to itself, then goes through the output block."""

    def __init__(self):
        super().__init__()
        self.self = BertAttention()
        self.output = BertAttOutput()

    def forward(self, input_tensor, attention_mask):
        """Return the attended, residual-normalised version of ``input_tensor``.

        ``attention_mask`` is handed to the attention module unchanged and may be None.
        """
        # Query and context are the same tensor for self-attention.
        attended = self.self(input_tensor, input_tensor, attention_mask)
        return self.output(attended, input_tensor)

In [12]:
class BertCrossattLayer(nn.Module):
    """Cross-attention layer: one input attends to a separate context tensor."""

    def __init__(self):
        super().__init__()
        self.att = BertAttention()
        self.output = BertAttOutput()

    def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None):
        """Attend from ``input_tensor`` to ``ctx_tensor``.

        Both tensors have their first two axes swapped before attention, and
        the result is returned in that swapped layout.
        """
        swap_first_two = (1, 0, 2)
        queries = input_tensor.permute(*swap_first_two)
        context = ctx_tensor.permute(*swap_first_two)
        attended = self.att(queries, context, ctx_att_mask)
        # The residual uses the swapped queries, so the output keeps their layout.
        return self.output(attended, queries)

In [13]:
def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit.

    Computes ``x * 0.5 * (1 + erf(x / sqrt(2)))`` element-wise.
    OpenAI GPT uses a slightly different tanh approximation:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    See https://arxiv.org/abs/1606.08415
    """
    erf_term = torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * (1.0 + erf_term)


class GeLU(nn.Module):
    """Module wrapper around the erf-based ``gelu`` function.

    It has no parameters, so it can be dropped into ``nn.Sequential`` like
    any other activation. See https://arxiv.org/abs/1606.08415
    """

    def forward(self, x):
        """Apply ``gelu`` element-wise to ``x``."""
        activated = gelu(x)
        return activated
    

In [14]:
# Weights & Biases logger handed to the Trainer below.
# NOTE(review): log_model='all' presumably uploads every checkpoint to the run - confirm against the installed Lightning version.
wandb_logger = WandbLogger(project='Sum_paper', log_model='all')
# BERT-style alias for the layer-norm class; it has to exist before any module that references it is instantiated.
BertLayerNorm = torch.nn.LayerNorm

In [15]:
# Hyper-parameters read when the Lightning module is constructed.
config = {
    'learning_rate': 0.000209,
    'batch_size': 16,
}

In [16]:
class MultTransfromer(pl.LightningModule):
    """Audio/video classifier that fuses a ResNet branch and a wav2vec2 branch with attention.

    The video branch averages frames over time, runs a truncated ResNet-50
    and projects the spatial map to 256 features. The audio branch runs
    wav2vec2 and projects its hidden states to 256 features. The two streams
    go through self-attention, cross-attention and self-attention again, are
    average-pooled, concatenated and classified into 3 classes.

    A batch is the nested dict produced by ``MetaDataset``:
    ``{'audio': {'sample', 'label'}, 'video': {'sample', 'label'}}``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): num_trainable_layers=4 unfreezes the last four children
        # of the full ResNet, of which presumably only 'layer3' is executed - confirm.
        self.model_ft = NewModel('resnet50','layer3',num_trainable_layers = 4)
        self.wav2vec = nn.Sequential(Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")) # output presumably (batch, 251, 768) - confirm

        # One cross-attention layer shared by both directions (audio->video and video->audio).
        self.visual_attention = BertCrossattLayer()

        self.lang_self_att = BertSelfattLayer()
        self.visn_self_att = BertSelfattLayer()

        # Projects each flattened 14x14 spatial map (196 values) to 256 features.
        self.to_trans_vid = nn.Sequential(
            nn.LayerNorm(196),
            nn.Linear(196, 256),
            nn.ReLU(),
            nn.Dropout(0.15)
        )

        # Projects the 768-dim wav2vec2 hidden states to 256 features.
        self.to_trans_aud = nn.Sequential(
            nn.LayerNorm(768),
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.15),
            nn.Linear(512, 256)
        )

        # Classifier over the concatenated (video, audio) pooled features.
        self.classifier = nn.Sequential(
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.15),
            nn.Linear(512, 3),

        )

        self.lr = config['learning_rate']
        self.batch_size = config['batch_size']

        self.pool_vid = nn.AdaptiveAvgPool1d(1)
        self.pool_aud = nn.AdaptiveAvgPool1d(1)

        # NOTE(review): a single collection is shared by train/val/test steps.
        self.metrics =  MetricCollection([Accuracy(), Precision(num_classes=3, average='macro'),
                                           Recall(num_classes=3, average='macro'), F1(num_classes=3, average='macro')])

    def cross_entropy_loss(self, logits, labels):
        """Cross-entropy between ``logits`` and integer ``labels`` (kept for API compatibility)."""
        return nn.CrossEntropyLoss()(logits, labels)

    def self_att(self, lang_input, visn_input, lang_attention_mask=None, visn_attention_mask=None):
        """Apply the per-modality self-attention layers; returns ``(audio_out, video_out)``."""
        lang_att_output = self.lang_self_att(lang_input, lang_attention_mask)
        visn_att_output = self.visn_self_att(visn_input, visn_attention_mask)
        return lang_att_output, visn_att_output

    def cross_att(self, lang_input,  visn_input, lang_attention_mask=None, visn_attention_mask=None):
        """Let each modality attend to the other; returns ``(audio_out, video_out)``.

        The cross-attention layer swaps the first two axes of its inputs, so
        the outputs come back with those axes swapped.
        """
        lang_att_output = self.visual_attention(lang_input, visn_input, ctx_att_mask=visn_attention_mask)
        visn_att_output = self.visual_attention(visn_input, lang_input, ctx_att_mask=lang_attention_mask)
        return lang_att_output, visn_att_output

    def forward(self, x):
        """Return class logits of shape ``(batch, 3)`` for a nested audio/video batch."""
        # assumes video is (B, n_frames, C, H, W) and audio is (B, 1, n_samples) - TODO confirm
        x_video = x['video']['sample']
        x_audio = x['audio']['sample']

        # TODO: try averaging over channels (-1) instead of over time.
        x_video = x_video.mean(1)  # average the frames
        x_video = self.model_ft(x_video)  # (B, d, 14, 14) for 224x224 input - TODO confirm
        # Use the actual batch size instead of the configured one so that
        # short batches or a different loader batch size do not break the reshape.
        batch = x_video.size(0)
        x_video = x_video.view(batch, -1, 14*14).permute(1,0,2)  # (d, B, H*W)
        x_video = self.to_trans_vid(x_video)  # (d, B, 256)
        # Explicit squeeze dims keep the batch axis when the batch has one item.
        x_audio = self.wav2vec(x_audio.squeeze(1)).last_hidden_state
        x_audio = x_audio.permute(1, 0, 2)  # (T, B, 768)
        x_audio = self.to_trans_aud(x_audio)  # (T, B, 256)

        # NOTE(review): at this point dim 0 is the sequence/channel axis and
        # dim 1 is the batch, while the attention module treats dim 0 as the
        # batch - so this first self-attention mixes information across the
        # samples of a batch. Left unchanged to stay compatible with trained
        # checkpoints; revisit.
        lang_att_output, visn_att_output = self.self_att(x_audio, x_video)
        # Cross-attention returns batch-first tensors, so the second
        # self-attention runs over the sequence axis.
        lang_att_output, visn_att_output = self.cross_att(lang_att_output, visn_att_output)
        lang_att_output, visn_att_output = self.self_att(lang_att_output, visn_att_output)

        # Residual connection from the projected inputs (moved to batch-first).
        lang_att_output = x_audio.permute(1, 0, 2) + lang_att_output
        visn_att_output = x_video.permute(1, 0, 2) + visn_att_output

        # Average over the sequence axis -> (B, 256) per modality.
        lang_att_output = self.pool_aud(lang_att_output.permute(0, 2, 1)).squeeze(-1)
        visn_att_output = self.pool_vid(visn_att_output.permute(0, 2, 1)).squeeze(-1)

        concated_features = torch.cat((visn_att_output, lang_att_output), dim=1)
        return self.classifier(concated_features)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and metrics; returns the loss."""
        _, loss, metrics = self._get_preds_loss_accuracy(batch)

        self.log('train_loss', loss)
        self.log('train_metrics', metrics)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and metrics; returns the predictions."""
        preds, loss, metrics = self._get_preds_loss_accuracy(batch)
        self.log('val_loss', loss)
        self.log('val_metrics', metrics)
        return preds

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss and metrics."""
        _, loss, metrics = self._get_preds_loss_accuracy(batch)

        self.log('test_loss', loss)
        self.log('test_metrics', metrics)

    def configure_optimizers(self):
        """Adam with a 10x smaller learning rate for the attention layers.

        The original iterated over the global ``model`` and compared full
        dotted parameter names against bare sub-module names, so the
        low-learning-rate group was always empty. Parameters are now taken
        from ``self`` and matched by sub-module prefix.

        NOTE(review): this changes the parameter-group sizes, so optimizer
        state stored in checkpoints written by the old code may not load.
        """
        low_lr_prefixes = ('visual_attention.', 'visn_self_att.', 'lang_self_att.')

        params_1x = []
        params_low_lr = []
        for name, param in self.named_parameters():
            if name.startswith(low_lr_prefixes):
                params_low_lr.append(param)
            else:
                params_1x.append(param)

        optimizer = torch.optim.Adam([
            {'params': params_1x},
            {'params': params_low_lr, 'lr': self.lr * 0.1}
        ], lr=self.lr, weight_decay=0.001)

#         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.99 ** epoch)

        return {
        'optimizer': optimizer,
#         'lr_scheduler': scheduler,
        'monitor': 'val_loss',
    }

    def on_validation_end(self):
        """Ask wandb to sync files matching the checkpoint glob after each validation run."""
        wandb.save('wandb/latest-run/files/*checkpoints')

    def _get_preds_loss_accuracy(self, batch):
        '''Shared by train/valid/test steps: returns ``(preds, loss, metrics)`` for a batch.'''
        # Audio and video items carry the same label; the video one is used.
        y = batch['video']['label']
        logits = self.forward(batch)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        # Metrics are computed on class probabilities.
        softmax_preds = torch.softmax(logits, dim=1)
        metrics = self.metrics(softmax_preds, y)
        return preds, loss, metrics

In [18]:
# Build the data module and model, attach wandb, and resume training from the last checkpoint.
dm = VgafDataModule()
model = MultTransfromer()
# count_parameters(model)
wandb_logger.watch(model)

# NOTE(review): the run id is hard-coded; `resume` is given a run id here - confirm
# the installed wandb version accepts that (newer versions expect `id=` plus `resume=`).
wandb.init(project="Sum_paper", resume='1ldxtiqc')


# Stop when val_loss has not improved for 6 checks.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=6,
    verbose=False,
    mode='min',
    check_on_train_epoch_end=True
)

# Keep the three best checkpoints by val_loss plus the most recent one.
# NOTE(review): the filename template references `val_f1`; confirm a metric with that key is logged.
checkpoint_callback = ModelCheckpoint(monitor='val_loss',
                                      save_top_k=3,
                                      save_last=True,
                                      save_weights_only=False,
                                      filename='checkpoint/{epoch:02d}-{val_loss:.4f}-{val_f1:.4f}',
                                      verbose=False,
                                      mode='min')


# trainer = pl.Trainer(
#     logger=wandb_logger, 
#                      gpus=-1, 
#     deterministic=True,
#     callbacks=[early_stop_callback, checkpoint_callback],
# #     auto_lr_find=True,
# #                      auto_scale_batch_size='binsearch',
#                        )

# trainer.fit(model, dm)
# lr_finder = trainer.tune(model, datamodule=dm)

# wandb.restore returns an open file object (or None when the file is not
# found), while the Trainer expects a path - pass the local file name.
restored_ckpt = wandb.restore('wandb/latest-run/files/Sum_paper/1ldxtiqc/checkpoints/last.ckpt')
resume_ckpt_path = restored_ckpt.name if restored_ckpt is not None else None

trainer = pl.Trainer(
    logger=wandb_logger,
    gpus=-1,
    deterministic=True,
    callbacks=[early_stop_callback, checkpoint_callback],
    resume_from_checkpoint=resume_ckpt_path)

trainer.fit(model, dm)

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

  rank_zero_warn(
Global seed set to 42
  rank_zero_warn(


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…






In [24]:
# trainer.fit(model, dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name             | Type              | Params
--------------------------------------------------------
0  | model_ft         | NewModel          | 25.6 M
1  | wav2vec          | Sequential        | 94.4 M
2  | visual_attention | BertCrossattLayer | 263 K 
3  | lang_self_att    | BertSelfattLayer  | 263 K 
4  | visn_self_att    | BertSelfattLayer  | 263 K 
5  | to_trans_vid     | Sequential        | 50.8 K
6  | to_trans_aud     | Sequential        | 526 K 
7  | classifier       | Sequential        | 2.6 K 
8  | pool_vid         | AdaptiveAvgPool1d | 0     
9  | pool_aud         | AdaptiveAvgPool1d | 0     
10 | metrics          | MetricCollection  | 0     
--------------------------------------------------------
119 M     Trainable params
1.4 M     Non-trainable params
121 M     Total params
485.199   Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Global seed set to 42


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

In [22]:
# Sync the checkpoint files of run 249h1fdz to wandb; the run id in the glob is hard-coded.
wandb.save('wandb/latest-run/files/Sum_paper/249h1fdz/checkpoints/*ckpt*')

['/workspace/wandb/run-20210825_204918-249h1fdz/files/wandb/latest-run/files/Sum_paper/249h1fdz/checkpoints/last.ckpt',
 '/workspace/wandb/run-20210825_204918-249h1fdz/files/wandb/latest-run/files/Sum_paper/249h1fdz/checkpoints/last.ckpt']

In [None]:
# Evaluate on the test split served by the data module; relies on `trainer` and `dm` from the training cell.
trainer.test(datamodule=dm)

In [None]:
# model.lr # 0.0002089296130854041