In [1]:
# %pip install -r requirements.txt

In [2]:
# %pip install transformers

In [3]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from collections import Counter
import os

  from .autonotebook import tqdm as notebook_tqdm


### Annotaions

In [4]:
# Load the annotations
annotations_train_path = 'data/processed_annotations_train.csv'
annotations_train = pd.read_csv(annotations_train_path)

annotations_val_path = 'data/processed_annotations_valid.csv'
annotations_val = pd.read_csv(annotations_val_path)

In [5]:
print(annotations_train.head())

                                       attachment_id text  \
0  data/tensors\train_9ede40ec-0c67-4a57-8337-003...    Н   
1  data/tensors\train_8fb4e5c9-a7df-4bad-b7d7-9d8...    Н   
2  data/tensors\train_d228da91-495a-4d6d-833a-64b...    Н   
3  data/tensors\train_397a486d-967f-444b-afd8-de1...    Н   
4  data/tensors\train_e1200d57-9b53-4ffa-ab7d-e3e...    Н   

                            user_id  height  width  length  train  
0  185bd3a81d9d618518d10abebf0d17a8    1920   1080    95.0   True  
1  46dd04a1caa75ed3082b573cb5a3ad26    1920   1080    37.0   True  
2  db573f94204e56e0cf3fc2ea000e5bdc    1280    720    45.0   True  
3  0211b488644476dd0fec656ccb9b74fc    1920   1080    52.0   True  
4  b07a773bcb10b4f14f33d2b0e8ec58ba     720   1280    69.0   True  


In [6]:
print(annotations_val.head())

                                       attachment_id text  \
0  data/tensors\valid_d1131f18-57d2-412f-8261-aca...    Н   
1  data/tensors\valid_016df4ed-6e78-40a1-974f-e5d...    Н   
2  data/tensors\valid_45557bda-2f3a-4882-a430-9cf...    Н   
3  data/tensors\valid_ab6ef8b7-1971-4218-ad0f-bfe...    Н   
4  data/tensors\valid_49d3d3c3-7288-4a05-96bf-76d...    Н   

                            user_id  height  width  length  train  
0  0ab4f8e463cdded2e59d6001f4e1b487    1080   1920    42.0  False  
1  1860c7f3605089bf0eff0c4e72c7add4    1440   1440    41.0  False  
2  7c989fdb633be7453228f7c650097205    1920   1080    30.0  False  
3  ba230d5b794c9a7350356eddd3e9929c    1920    960    41.0  False  
4  3dd2ce2659aada17b976390004ebe322    1920    886    39.0  False  


### Pretrained models

In [7]:
# # Load Swin Transformer and MViTv2 models as generic models
swin_transformer_model = torch.jit.load('models/swin16-3.pt')
mvitv2_model = torch.jit.load('models/mvit16-4.pt')

In [8]:
# Make sure to call eval() if you're using the models for inference
swin_transformer_model.eval()
mvitv2_model.eval()

RecursiveScriptModule(
  original_name=Recognizer3D
  (data_preprocessor): RecursiveScriptModule(original_name=ActionDataPreprocessor)
  (backbone): RecursiveScriptModule(
    original_name=MViT
    (patch_embed): RecursiveScriptModule(
      original_name=PatchEmbed3D
      (projection): RecursiveScriptModule(original_name=Conv3d)
    )
    (blocks): RecursiveScriptModule(
      original_name=ModuleList
      (0): RecursiveScriptModule(
        original_name=MultiScaleBlock
        (norm1): RecursiveScriptModule(original_name=LayerNorm)
        (attn): RecursiveScriptModule(
          original_name=MultiScaleAttention
          (qkv): RecursiveScriptModule(original_name=Linear)
          (proj): RecursiveScriptModule(original_name=Linear)
          (pool_q): RecursiveScriptModule(original_name=Conv3d)
          (norm_q): RecursiveScriptModule(original_name=LayerNorm)
          (pool_k): RecursiveScriptModule(original_name=Conv3d)
          (norm_k): RecursiveScriptModule(original_name

In [9]:
additional_models = {
    # 'swin_transformer': swin_transformer_model,
    'mvitv2': mvitv2_model
}

In [10]:
def adjust_inputs_for_model(inputs):
    adjusted_inputs = inputs.squeeze()
    return adjusted_inputs

In [11]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [12]:
class PaddedSignLanguageDataset(Dataset):
    def __init__(self, annotations, transform=None, max_length=None):
        """
        Corrected custom dataset for loading sign language video tensors with padding.

        Args:
        annotations (DataFrame): DataFrame containing the annotations.
        transform (callable, optional): Optional transform to be applied on a sample.
        max_length (int, optional): Maximum length of the video tensors. If not provided, it will be calculated.
        """
        self.annotations = annotations
        self.transform = transform
        self.max_length = 132

        if self.max_length is None:
            # Calculate the maximum length among all tensors
            self.max_length = max(len(torch.load(row['attachment_id'], map_location=torch.device('cpu'))) for _, row in annotations.iterrows())

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        tensor_path = self.annotations.iloc[idx]['attachment_id']
        label = self.annotations.iloc[idx]['text']
        
        # Load the tensor
        tensor = torch.load(tensor_path, map_location=torch.device('cpu'))

        # Check if the tensor is empty or None
        if tensor is None or len(tensor) == 0:
            print(f"Empty tensor found at index {idx}.")
            return None, label


        # Pad the tensor to the maximum length
        padded_tensor = torch.zeros((self.max_length, *tensor[0].shape))
        padded_tensor[:len(tensor)] = torch.stack(tensor)
                
        # Apply transform if any
        if self.transform:
            padded_tensor = self.transform(padded_tensor)

        return padded_tensor, label

In [13]:
# Create the padded dataset and dataloader
padded_dataset_train = PaddedSignLanguageDataset(annotations_train)
padded_dataloader_train = DataLoader(padded_dataset_train, batch_size=4, shuffle=True)

padded_dataset_val = PaddedSignLanguageDataset(annotations_val)
padded_dataloader_val = DataLoader(padded_dataset_val, batch_size=4, shuffle=True)

### Labels Mapping

In [14]:
# Initialize an empty set to collect unique labels
unique_labels_train = set()

# Iterate over your dataset to collect unique labels
for _, label_data in padded_dataloader_train:
    unique_labels_train.update(label_data)

# Sort the labels for consistency
sorted_labels = sorted(unique_labels_train)

# Create the label mapping
train_label_mapping = {label: idx for idx, label in enumerate(sorted_labels)}

# Print the label mapping
print("Label Mapping:", train_label_mapping)

Label Mapping: {'Н': 0, 'вероятно': 1, 'значить': 2, 'обычный': 3, 'овца': 4, 'продлевать': 5, 'суматоха': 6, 'такой': 7, 'цветовой оттенок': 8, 'черный': 9}


In [15]:
# Initialize an empty set to collect unique labels
unique_labels_test = set()

# Iterate over your dataset to collect unique labels
for _, label_data in padded_dataloader_val:
    unique_labels_test.update(label_data)

# Sort the labels for consistency
sorted_labels = sorted(unique_labels_test)

# Create the label mapping
val_label_mapping = {label: idx for idx, label in enumerate(sorted_labels)}

# Print the label mapping
print("Label Mapping:", val_label_mapping)

Label Mapping: {'Н': 0, 'вероятно': 1, 'значить': 2, 'обычный': 3, 'овца': 4, 'продлевать': 5, 'суматоха': 6, 'такой': 7, 'цветовой оттенок': 8, 'черный': 9}


In [16]:
from classes import classes

In [17]:
results = {}

In [18]:
class TorchScriptModelWrapper(nn.Module):
    def __init__(self, scripted_model, output_classes=10):
        super(TorchScriptModelWrapper, self).__init__()
        self.scripted_model = scripted_model
        # Assuming you know the number of features from the original model's output
        self.new_fc = nn.Linear(1001, output_classes)

    def forward(self, x):
        x = self.scripted_model(x)
        # The output of the scripted model is then passed through the new fully connected layer
        x = self.new_fc(x)
        return x

In [19]:
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

In [20]:
for model_name, model in additional_models.items():
    if torch.jit.isinstance(model, torch.jit.ScriptModule):
        additional_models[model_name] = TorchScriptModelWrapper(model)

    model = additional_models[model_name]
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_function = nn.CrossEntropyLoss()
    model.eval()

    val_correct_predictions = 0
    val_total_predictions = 0

    for epoch in range(10):
        with tqdm(enumerate(padded_dataloader_val), total=len(padded_dataloader_val), desc=f"Model: {model_name}") as t_outer:
            for i, data in t_outer:
                inputs, label_data = data
                if len(inputs) == 0 or len(label_data) == 0:
                    continue

                labels = torch.tensor([val_label_mapping[label] for label in label_data], dtype=torch.long)

                batch_corr = 0.
                batch_total = 0.

                with tqdm(enumerate(inputs), total=len(inputs), leave=False, desc="Processing inputs") as t_inner:
                    for e, inp in t_inner:
                        target = labels[e]
                        target_probs = torch.zeros(1, 10)
                        target_probs[0][target] = 1
                        j = 0
                        inp_pred = []
                        num_splits = inp.shape[0] // 16
                        while j + num_splits <= inp.shape[0]:
                            input_splits = inp[j:j+num_splits]
                            input_splits = torch.reshape(input_splits, (1, 3, -1, 64, 64))
                            input_splits = np.stack(input_splits[:num_splits], axis=1)[None][None]
                            input_splits = torch.from_numpy(input_splits)
                            
                            optimizer.zero_grad()

                            outputs = model(input_splits)

                            # print(outputs.shape, target_probs.shape)
                            loss = loss_function(outputs, target_probs)
                            loss.backward()

                            optimizer.step()

                            pred_class = torch.argmax(outputs, dim=1)

                            inp_pred.append(pred_class.item())
                            j += 1

                        pred_counter = Counter(inp_pred)
                        pred_label = pred_counter.most_common(1)[0][0]

                        print(f'pred_label {pred_label}, label {target}')

                        batch_total += 1
                        batch_corr += (pred_label == labels[e]).sum().item()

                        t_inner.set_postfix(batch_accuracy=batch_corr/batch_total)

                val_correct_predictions += batch_corr
                val_total_predictions += batch_total

            val_accuracy = val_correct_predictions / val_total_predictions
            t_outer.set_postfix(model_accuracy=val_accuracy, epoch=epoch)

            checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_epoch_{epoch}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                'val_accuracy': val_accuracy
            }, checkpoint_path)


            results[model_name] = val_accuracy

Model: mvitv2:   0%|          | 0/13 [00:00<?, ?it/s]



pred_label 7, label 7




pred_label 7, label 6




pred_label 7, label 7


Model: mvitv2:   8%|▊         | 1/13 [13:05<2:37:03, 785.26s/it]

pred_label 7, label 9




pred_label 7, label 1




pred_label 7, label 5




pred_label 7, label 4


Model: mvitv2:  15%|█▌        | 2/13 [26:20<2:25:04, 791.29s/it]

pred_label 4, label 5




pred_label 5, label 1




pred_label 5, label 0




pred_label 5, label 0


Model: mvitv2:  23%|██▎       | 3/13 [39:55<2:13:39, 801.96s/it]

pred_label 0, label 0




pred_label 0, label 3




pred_label 0, label 2




pred_label 0, label 8


Model: mvitv2:  31%|███       | 4/13 [53:50<2:02:17, 815.23s/it]

pred_label 0, label 3




pred_label 0, label 5




pred_label 5, label 4




pred_label 4, label 9


Model: mvitv2:  38%|███▊      | 5/13 [1:07:32<1:48:59, 817.40s/it]

pred_label 5, label 3




pred_label 3, label 2




pred_label 3, label 6




pred_label 4, label 4


Model: mvitv2:  46%|████▌     | 6/13 [1:21:00<1:34:58, 814.13s/it]

pred_label 4, label 7




pred_label 4, label 8




pred_label 4, label 4




pred_label 4, label 1


Model: mvitv2:  54%|█████▍    | 7/13 [1:34:16<1:20:49, 808.31s/it]

pred_label 4, label 8




pred_label 4, label 6




pred_label 4, label 8




pred_label 8, label 5


Model: mvitv2:  62%|██████▏   | 8/13 [1:47:12<1:06:30, 798.13s/it]

pred_label 8, label 9




pred_label 8, label 6




pred_label 8, label 0




pred_label 8, label 8


Model: mvitv2:  69%|██████▉   | 9/13 [1:59:59<52:33, 788.38s/it]  

pred_label 8, label 1




pred_label 8, label 5




pred_label 8, label 6




pred_label 6, label 0


Model: mvitv2:  77%|███████▋  | 10/13 [2:13:00<39:18, 786.06s/it]

pred_label 6, label 7




pred_label 6, label 9




pred_label 6, label 2




pred_label 6, label 3


Model: mvitv2:  85%|████████▍ | 11/13 [2:25:58<26:07, 783.58s/it]

pred_label 6, label 1




pred_label 1, label 4




pred_label 1, label 2




pred_label 2, label 2


Model: mvitv2:  92%|█████████▏| 12/13 [2:38:44<12:58, 778.31s/it]

pred_label 2, label 7




pred_label 2, label 9


Model: mvitv2: 100%|██████████| 13/13 [2:45:19<00:00, 763.01s/it]

pred_label 2, label 3



Model: mvitv2:   0%|          | 0/13 [00:00<?, ?it/s]

pred_label 2, label 4




pred_label 4, label 7




pred_label 7, label 0


Model: mvitv2:   8%|▊         | 1/13 [13:23<2:40:36, 803.01s/it]

pred_label 7, label 6




pred_label 7, label 7




pred_label 7, label 8




pred_label 7, label 0


Model: mvitv2:  15%|█▌        | 2/13 [26:49<2:27:38, 805.33s/it]

pred_label 7, label 1




pred_label 7, label 2




pred_label 2, label 4




pred_label 4, label 8


Model: mvitv2:  23%|██▎       | 3/13 [39:56<2:12:46, 796.64s/it]

pred_label 8, label 9




pred_label 2, label 2




pred_label 2, label 6




pred_label 2, label 5


Model: mvitv2:  31%|███       | 4/13 [52:44<1:57:47, 785.28s/it]

pred_label 2, label 9




pred_label 9, label 4




pred_label 4, label 1




pred_label 4, label 3


Model: mvitv2:  38%|███▊      | 5/13 [1:24:39<2:39:02, 1192.86s/it]

pred_label 4, label 3




pred_label 3, label 3




pred_label 3, label 2




pred_label 3, label 8


Model: mvitv2:  46%|████▌     | 6/13 [1:38:37<2:05:06, 1072.29s/it]

pred_label 3, label 3




pred_label 3, label 9




pred_label 3, label 6




pred_label 3, label 4


Model: mvitv2:  54%|█████▍    | 7/13 [1:52:23<1:39:10, 991.80s/it] 

pred_label 3, label 7




pred_label 3, label 0




pred_label 3, label 8




pred_label 8, label 5


Model: mvitv2:  62%|██████▏   | 8/13 [2:05:38<1:17:24, 928.93s/it]

pred_label 8, label 9




pred_label 9, label 1




pred_label 1, label 2




pred_label 1, label 5


Model: mvitv2:  69%|██████▉   | 9/13 [2:19:03<59:21, 890.27s/it]  

pred_label 1, label 0




pred_label 1, label 4




pred_label 1, label 9




pred_label 9, label 7


Model: mvitv2:  77%|███████▋  | 10/13 [2:33:15<43:55, 878.52s/it]

pred_label 9, label 3




pred_label 9, label 6




pred_label 9, label 8




pred_label 5, label 5


Model: mvitv2:  85%|████████▍ | 11/13 [2:47:29<29:02, 871.01s/it]

pred_label 5, label 7




pred_label 7, label 1




pred_label 7, label 6




pred_label 6, label 2


Model: mvitv2:  92%|█████████▏| 12/13 [3:02:28<14:39, 879.36s/it]

pred_label 1, label 1




pred_label 1, label 0


Model: mvitv2: 100%|██████████| 13/13 [3:09:26<00:00, 874.35s/it]

pred_label 1, label 5



Model: mvitv2:   0%|          | 0/13 [00:00<?, ?it/s]

pred_label 5, label 5




pred_label 5, label 4




pred_label 5, label 7


Model: mvitv2:   8%|▊         | 1/13 [14:56<2:59:14, 896.25s/it]

pred_label 5, label 2




pred_label 5, label 9




pred_label 5, label 3




pred_label 7, label 7


Model: mvitv2:  15%|█▌        | 2/13 [29:26<2:41:30, 880.99s/it]

pred_label 7, label 2




pred_label 2, label 5




pred_label 5, label 9




pred_label 5, label 4


Model: mvitv2:  23%|██▎       | 3/13 [43:42<2:24:57, 869.72s/it]

pred_label 5, label 4




pred_label 5, label 5




pred_label 5, label 0




pred_label 5, label 4


Model: mvitv2:  31%|███       | 4/13 [58:22<2:11:03, 873.73s/it]

pred_label 4, label 4




pred_label 4, label 9




pred_label 4, label 8




pred_label 4, label 8


Model: mvitv2:  38%|███▊      | 5/13 [1:12:21<1:54:49, 861.25s/it]

pred_label 4, label 3




pred_label 4, label 6




pred_label 4, label 6




pred_label 4, label 0


Model: mvitv2:  46%|████▌     | 6/13 [1:27:21<1:41:59, 874.18s/it]

pred_label 4, label 7




pred_label 6, label 6




pred_label 6, label 7




pred_label 7, label 1


Model: mvitv2:  54%|█████▍    | 7/13 [1:42:27<1:28:27, 884.57s/it]

pred_label 7, label 8




pred_label 8, label 8




pred_label 8, label 5




pred_label 8, label 0


Model: mvitv2:  62%|██████▏   | 8/13 [1:58:24<1:15:39, 907.87s/it]

pred_label 8, label 9




pred_label 8, label 2




pred_label 8, label 1




pred_label 8, label 2


Model: mvitv2:  69%|██████▉   | 9/13 [2:20:17<1:08:58, 1034.54s/it]

pred_label 2, label 6




pred_label 6, label 8




pred_label 8, label 7




pred_label 8, label 3


Model: mvitv2:  77%|███████▋  | 10/13 [2:35:36<49:56, 998.80s/it]  

pred_label 8, label 5




pred_label 6, label 6




pred_label 6, label 1




pred_label 6, label 1


Model: mvitv2:  85%|████████▍ | 11/13 [3:16:04<47:52, 1436.17s/it]

pred_label 1, label 3




pred_label 1, label 0




pred_label 1, label 2




pred_label 1, label 3


Model: mvitv2:  92%|█████████▏| 12/13 [3:34:06<22:08, 1328.29s/it]

pred_label 3, label 9




pred_label 1, label 1


Model: mvitv2: 100%|██████████| 13/13 [3:43:05<00:00, 1029.64s/it]

pred_label 1, label 0



Model: mvitv2:   0%|          | 0/13 [00:00<?, ?it/s]

pred_label 1, label 3




pred_label 3, label 0




pred_label 0, label 4


Model: mvitv2:   8%|▊         | 1/13 [14:44<2:56:52, 884.34s/it]

pred_label 3, label 3




pred_label 3, label 6




pred_label 3, label 6




pred_label 6, label 8


Model: mvitv2:  15%|█▌        | 2/13 [28:55<2:38:32, 864.77s/it]

pred_label 6, label 7




pred_label 6, label 0




pred_label 0, label 9




pred_label 0, label 9


Model: mvitv2:  23%|██▎       | 3/13 [44:11<2:28:02, 888.20s/it]

pred_label 9, label 8




pred_label 9, label 2




pred_label 9, label 7




pred_label 7, label 7


Model: mvitv2:  31%|███       | 4/13 [58:30<2:11:30, 876.69s/it]

pred_label 7, label 8




pred_label 8, label 5




pred_label 8, label 3




pred_label 8, label 4


Model: mvitv2:  38%|███▊      | 5/13 [1:12:39<1:55:34, 866.80s/it]

pred_label 8, label 1




pred_label 8, label 2




pred_label 8, label 8




pred_label 8, label 1


Model: mvitv2:  46%|████▌     | 6/13 [1:27:21<1:41:41, 871.71s/it]

pred_label 8, label 0




pred_label 8, label 5




pred_label 1, label 1




pred_label 1, label 7


Model: mvitv2:  54%|█████▍    | 7/13 [1:44:22<1:32:03, 920.56s/it]

pred_label 1, label 9




pred_label 7, label 7




pred_label 7, label 2




pred_label 7, label 8


Model: mvitv2:  62%|██████▏   | 8/13 [2:00:03<1:17:16, 927.23s/it]

pred_label 7, label 4




pred_label 7, label 9




pred_label 9, label 4




pred_label 9, label 9


Model: mvitv2:  69%|██████▉   | 9/13 [2:17:25<1:04:12, 963.17s/it]

pred_label 9, label 5




pred_label 9, label 2




pred_label 9, label 1




pred_label 9, label 6


Model: mvitv2:  77%|███████▋  | 10/13 [2:31:18<46:08, 922.95s/it] 

pred_label 9, label 0




pred_label 9, label 3




pred_label 9, label 3




pred_label 3, label 5


Model: mvitv2:  85%|████████▍ | 11/13 [2:50:41<33:12, 996.26s/it]

pred_label 5, label 6




pred_label 5, label 5
