In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only input tree and print the full path of every file found.
for root, _, names in os.walk('/kaggle/input'):
    full_paths = [os.path.join(root, name) for name in names]
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/contradictory-my-dear-watson/sample_submission.csv
/kaggle/input/contradictory-my-dear-watson/train.csv
/kaggle/input/contradictory-my-dear-watson/test.csv


# Import needed libraries

In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

# Device agnostic code

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Make the chosen device the default for newly created tensors.
torch.set_default_device(device)
print(f"default device set to {device}")

default device set to cpu


In [4]:
# Read the competition's training CSV into a DataFrame.
train_df = pd.read_csv(
    filepath_or_buffer="/kaggle/input/contradictory-my-dear-watson/train.csv",
)
# Show the first five rows as the cell output.
train_df.head(n=5)

Unnamed: 0,id,premise,hypothesis,lang_abv,language,label
0,5130fd2cb5,and these comments were considered in formulat...,The rules developed in the interim were put to...,en,English,0
1,5b72532a0b,These are issues that we wrestle with in pract...,Practice groups are not permitted to work on t...,en,English,2
2,3931fbe82a,Des petites choses comme celles-là font une di...,J'essayais d'accomplir quelque chose.,fr,French,0
3,5622f0c60b,you know they can't really defend themselves l...,They can't defend themselves because of their ...,en,English,0
4,86aaa48b45,ในการเล่นบทบาทสมมุติก็เช่นกัน โอกาสที่จะได้แสด...,เด็กสามารถเห็นได้ว่าชาติพันธุ์แตกต่างกันอย่างไร,th,Thai,1


In [5]:
# Drop the id/language columns from the training frame. Rebinding the name
# (instead of dropping in place) and ignoring columns that are already gone
# keeps this cell safe to re-run: a second in-place drop would raise KeyError.
train_df = train_df.drop(columns=["id", "lang_abv", "language"], errors="ignore")

In [6]:
# Display five randomly chosen rows of the trimmed training frame.
train_df.sample(n=5)

Unnamed: 0,premise,hypothesis,label
4105,Unaona mbegu ya ndani ya Dominant Mendelian il...,Mazingira hayakuwa sahihi kila wakati kwa jeni...,0
3950,well the first thing for me is i wonder i see ...,"Talking about privacy is a complicated topic, ...",0
6767,yeah i do remember that and uh i remember as a...,I watched the Ed Sullivan Show when I was ten-...,1
5942,Những thứ khác không làm cho người tiêu dùng đ...,Các quả bóng chày không đủ trắng.,1
7241,This popular show spawned the aquatic show at ...,Bellagio's water display is now just as well l...,1


In [7]:
# Load the tokenizer that belongs to the multilingual cased BERT checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-multilingual-cased",
)

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

In [8]:
class CustomDataset(Dataset):
    """Premise/hypothesis pair dataset that tokenizes each row on access.

    Args:
        df: Frame with "premise", "hypothesis" and "label" columns. Rows are
            addressed by position, so the index does not need to be reset.
        tokenizer: Callable tokenizer, invoked as
            ``tokenizer(premise, hypothesis, ...)``, whose result provides
            "input_ids", "token_type_ids" and "attention_mask" tensors.
        max_length: Maximum number of tokens kept per encoded pair; longer
            pairs are truncated by the tokenizer.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: object, max_length: int):
        self.dataset = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of premise/hypothesis pairs."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Encode the pair at position ``idx``.

        Returns:
            A dict with flattened "input_ids", "token_type_ids" and
            "attention_mask" tensors and a scalar ``torch.long`` "label".
        """
        premise = self.dataset["premise"].iloc[idx]
        hypothesis = self.dataset["hypothesis"].iloc[idx]

        # Call the tokenizer directly: `encode_plus` is deprecated in favour of
        # `__call__`. Padding is disabled here because a single example has
        # nothing to be padded against; `collate_fn` pads once per batch.
        token_dict = self.tokenizer(
            premise,
            hypothesis,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            # return_tensors="pt" adds a leading batch axis; flatten it away.
            "input_ids": token_dict["input_ids"].flatten(),
            "token_type_ids": token_dict["token_type_ids"].flatten(),
            "attention_mask": token_dict["attention_mask"].flatten(),
            "label": torch.tensor(self.dataset["label"].iloc[idx], dtype=torch.long),
        }

    @staticmethod
    def collate_fn(batch):
        """Pad a list of ``__getitem__`` results to a common length and stack them.

        Args:
            batch: List of dicts as produced by ``__getitem__``.

        Returns:
            A dict with "input_ids", "token_type_ids" and "attention_mask"
            tensors padded to the longest item in the batch (batch axis first),
            and the stacked "labels" tensor.
        """

        def _pad(key):
            # Right-pad with zeros. For the attention mask, 0 marks the padded
            # positions, which is what lets the model ignore the padded ids
            # and type ids.
            # NOTE(review): assumes the tokenizer's pad token id is 0 - confirm
            # if a non-BERT tokenizer is used.
            return pad_sequence(
                [batch_item[key] for batch_item in batch],
                batch_first=True,
                padding_value=0,
            )

        return {
            "input_ids": _pad("input_ids"),
            "token_type_ids": _pad("token_type_ids"),
            "attention_mask": _pad("attention_mask"),
            "labels": torch.stack([batch_item["label"] for batch_item in batch], dim=0),
        }
        

In [9]:
# Truncation limit (in tokens) for every encoded premise/hypothesis pair.
# The longest raw text, measured in characters, serves as a generous upper
# bound, but characters are not tokens, so the bound is capped at the
# tokenizer's own limit; without the cap, truncation could let through
# sequences longer than the tokenizer reports as supported.
longest_text_chars = int(
    max(train_df["premise"].str.len().max(), train_df["hypothesis"].str.len().max())
)
max_length = min(longest_text_chars, tokenizer.model_max_length)

# Hold out 25% of the rows for validation; the fixed seed keeps the split
# identical across kernel restarts.
train_split, val_split = train_test_split(train_df, test_size=0.25, shuffle=True, random_state=42)

train_dataset = CustomDataset(df=train_split, tokenizer=tokenizer, max_length=max_length)
val_dataset = CustomDataset(df=val_split, tokenizer=tokenizer, max_length=max_length)

In [10]:
# Both loaders batch four examples and pad each batch via CustomDataset.collate_fn.
# The generator is created on `device`; presumably this is needed so the
# loader's random state matches the default device set earlier - confirm.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    generator=torch.Generator(device=device),
    collate_fn=train_dataset.collate_fn,
)
# Validation batches are served in a fixed order: shuffling adds nothing to
# evaluation and only makes the inspected batches non-deterministic.
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=4,
    shuffle=False,
    generator=torch.Generator(device=device),
    collate_fn=val_dataset.collate_fn,
)

In [11]:
# Fetch a single collated batch from the validation loader to inspect its layout.
val_batch_iterator = iter(val_dataloader)
next(val_batch_iterator)

{'input_ids': tensor([[   101,  11160,  14204,  10894,  10115,    112,    188,  10529,  27904,
           10336,    119,    102,  14535,  16032,  19556,  16247,    117,  24682,
           10894,  10529,  19556,  10681, 110353,    119,    102,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0],
         [   101,    511,  30318,  69373,  33190,    546,  44890,  91464,  10960,
           12189,  24373,  10696,  66806,  10405,  10868,  10387,  10752,  40383,
           10122,  86473,  10913,  27796,  12753,  53156,  12673,    136,    102,
             526, 