In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
import random

In [5]:
class NN(nn.Module):
    """Toy module: a single 1 -> 1 linear layer applied to the sum of two inputs."""

    def __init__(self):
        super().__init__()
        # One affine map with a single input feature and a single output feature.
        self.layer = nn.Linear(in_features=1, out_features=1)

    def forward(self, x1, x2):
        """Return ``self.layer`` evaluated on the element-wise sum ``x1 + x2``."""
        summed = x1 + x2
        return self.layer(summed)

In [7]:
# Instantiate the toy NN module used by the gradient-flow check in the following cells.
model = NN()

In [13]:
# One-element random input.
x0 = torch.randn(1)

# First pass runs with autograd disabled and the result is detached,
# so x1 enters the second pass as a plain value with no history.
with torch.no_grad():
    x1 = model(x0, x0).detach()

# Second pass is tracked by autograd; it is the only call the loss can backpropagate through.
x2 = model(x1, x0)

In [14]:
# Squared error between the second-pass output and the original input.
loss = (x2 - x0).square()

In [15]:
# Backpropagate the loss through the tracked (second) forward pass.
loss.backward()

In [16]:
# For every parameter, print its name and whether its gradient is still unset after backward().
for name, param in model.named_parameters():
    grad_missing = param.grad is None
    print(name, grad_missing)

layer.weight False
layer.bias False


In [1]:
from datasets import Dataset, disable_progress_bar
from datasets.utils.logging import set_verbosity_error
from itertools import cycle
import json
import gc
import torch
import numpy as np
import torch.distributed as dist
from typing import List

In [5]:
class RocStoryDatasetDDP:
    """Line-per-example text loader that shards data across distributed ranks.

    Each text is tokenized, cut at a random position into a condition part and
    an input part, decoded back to text, and re-encoded with the condition and
    generation tokenizers.

    The tokenizer arguments are used like Hugging Face tokenizers: they are
    called with ``add_special_tokens`` / ``padding`` / ``truncation`` /
    ``max_length`` and must return ``"input_ids"`` and ``"attention_mask"``;
    ``tokenizer_bert`` must additionally provide ``batch_decode``.

    Args:
        split: ``"train"`` or ``"valid"``; anything else makes ``get_data`` raise.
        tokenizer_bert: tokenizer used to find the split point (encode + decode).
        tokenizer_cond: tokenizer that encodes the condition text.
        tokenizer_gen: tokenizer that encodes the input (post-split) text.
        max_sequence_len: padding/truncation length for every tokenizer call.
        pos_begin: lower bound of the split position, as a fraction of
            ``max_sequence_len``.
        pos_end: upper bound of the split position, as a fraction of
            ``max_sequence_len``.
        train_path: text file read for the train split (one example per line).
        valid_path: text file read for the valid split (one example per line).
        num_proc: number of worker processes passed to ``Dataset.map``.
    """

    # Defaults keep the locations that were previously hard-coded in get_data.
    DEFAULT_TRAIN_PATH = "/home/vmeshchaninov/nlp_models/data/rocstories/train/data.txt"
    DEFAULT_VALID_PATH = "/home/vmeshchaninov/nlp_models/data/rocstories/validation/data.txt"

    def __init__(self,
                 split, tokenizer_bert, tokenizer_cond, tokenizer_gen, max_sequence_len,
                 pos_begin: float = 0., pos_end: float = 0.67,
                 train_path: str = DEFAULT_TRAIN_PATH,
                 valid_path: str = DEFAULT_VALID_PATH,
                 num_proc: int = 30):
        self.split = split
        self.tokenizer_bert = tokenizer_bert
        self.tokenizer_cond = tokenizer_cond
        self.tokenizer_gen = tokenizer_gen
        self.max_sequence_len = max_sequence_len
        self.max_cond_len = max_sequence_len
        self.pos_begin = pos_begin
        self.pos_end = pos_end
        self.train_path = train_path
        self.valid_path = valid_path
        self.num_proc = num_proc

        # Fall back to a single-process layout when no process group is initialised.
        distributed = dist.is_initialized()
        self.device_id = dist.get_rank() if distributed else 0
        self.total_device_number = dist.get_world_size() if distributed else 1

        # Seeds the train-split shuffle; advanced by get_data after every train pass.
        self.epoch = 0

    def spilt_data_across_gpu(self, dt: List[str]):
        """Return the slice of ``dt`` that belongs to this rank.

        For the train split the order is first shuffled with a generator
        seeded by ``self.epoch`` (so every rank computes the same permutation);
        other splits keep file order. Every rank receives
        ``len(dt) // total_device_number`` items, except the last rank, which
        also takes the remainder.
        """
        if self.split == "train":
            indexes = np.random.default_rng(seed=self.epoch).permutation(len(dt))
        else:
            indexes = np.arange(len(dt))

        shard_size = len(dt) // self.total_device_number
        start_ind = self.device_id * shard_size
        is_last_rank = (self.device_id + 1) == self.total_device_number
        # The last rank absorbs the leftover items so nothing is dropped.
        end_ind = len(dt) if is_last_rank else start_ind + shard_size

        return [dt[i] for i in indexes[start_ind:end_ind]]

    # Correctly spelled alias; the original (misspelled) name is kept for existing callers.
    split_data_across_gpu = spilt_data_across_gpu

    def load_data(self, path):
        """Read ``path``, keep this rank's shard, tokenize it, and return the dataset.

        The processed dataset is also stored on ``self.dt`` and is formatted to
        return torch tensors for the four id/mask columns.
        """
        with open(path, "r") as file:
            lines = [line.strip() for line in file]
        lines = self.spilt_data_across_gpu(lines)
        dataset = Dataset.from_list([{"text": t} for t in lines])

        self.dt = dataset.map(
            self.batch_preprocessing,
            batched=True,
            # Splits are random, so a cached result must not be reused.
            load_from_cache_file=False,
            num_proc=self.num_proc,
            desc="Dataset tokenization",
            batch_size=1000,
        )
        self.dt = self.dt.with_format("pt", columns=["input_ids", "cond_ids", "input_mask", "cond_mask"])
        return self.dt

    def batch_preprocessing(self, batch):
        """Split every text of ``batch["text"]`` into condition and input parts.

        Returns a dict with ``input_ids`` / ``input_mask`` (from
        ``tokenizer_gen``) and ``cond_ids`` / ``cond_mask`` (from
        ``tokenizer_cond``), all padded/truncated to ``max_sequence_len``.
        """
        # Tokenize with the same tokenizer that later decodes the pieces; ids
        # produced by one tokenizer are only meaningful to that tokenizer.
        input_ids = self.tokenizer_bert(
            batch["text"],
            add_special_tokens=False,
            padding="max_length",
            truncation=True,
            max_length=self.max_sequence_len,
        )["input_ids"]

        # Random split position per example, uniform in
        # [pos_begin, pos_end) * max_cond_len and truncated to int.
        # NOTE(review): positions are relative to the padded length, so for
        # short texts the cut may land inside padding — confirm this is intended.
        batch_size = len(batch["text"])
        elem_counts = self.max_cond_len
        fractions = np.random.rand(batch_size) * (self.pos_end - self.pos_begin) + self.pos_begin
        delimeter_poses = (fractions * elem_counts).astype(int)

        cond_ids_list = []
        input_ids_list = []
        for element_ids, delimeter in zip(input_ids, delimeter_poses):
            cond_ids_list.append(element_ids[:delimeter])
            input_ids_list.append(element_ids[delimeter:])

        # Tokens decode
        texts_cond = self.tokenizer_bert.batch_decode(cond_ids_list, skip_special_tokens=True)
        texts_input = self.tokenizer_bert.batch_decode(input_ids_list, skip_special_tokens=True)

        # Text encode
        cond_ = self.tokenizer_cond(
            texts_cond,
            add_special_tokens=True,
            padding="max_length",
            truncation=True,
            max_length=self.max_sequence_len,
        )

        input_ = self.tokenizer_gen(
            texts_input,
            add_special_tokens=True,
            padding="max_length",
            truncation=True,
            max_length=self.max_sequence_len,
        )

        return {
            "input_ids": input_["input_ids"],
            "cond_ids": cond_["input_ids"],
            "input_mask": input_["attention_mask"],
            "cond_mask": cond_["attention_mask"],
        }

    def clear_data(self):
        """Drop the cached dataset (if any) and run the garbage collector."""
        # Guard so that calling this before any load_data() is a no-op.
        if hasattr(self, "dt"):
            del self.dt
        gc.collect()

    def get_data(self):
        """Endlessly yield freshly processed datasets for the configured split.

        This is a generator, so an invalid split is reported on the first
        ``next()`` call rather than at call time.

        Raises:
            ValueError: if ``self.split`` is neither ``"train"`` nor ``"valid"``.
        """
        if self.split == "valid":
            while True:
                yield self.load_data(self.valid_path)
        elif self.split == "train":
            while True:
                yield self.load_data(self.train_path)
                # Advance the shuffle seed so the next pass uses a new permutation.
                self.epoch += 1
        else:
            raise ValueError("Wrong data split")

In [2]:
from transformers import BertTokenizerFast

In [3]:
# Name of the pretrained checkpoint whose tokenizer is loaded on the next line.
bert_cfg = "bert-base-uncased"
tokenizer_bert = BertTokenizerFast.from_pretrained(bert_cfg)

In [11]:
# Validation-split loader; the same BERT tokenizer fills all three tokenizer roles,
# and sequences are padded/truncated to 32 tokens.
dt = RocStoryDatasetDDP("valid", tokenizer_bert, tokenizer_bert, tokenizer_bert, 32)

In [12]:
# Pull the first processed dataset from the generator.
# NOTE(review): this rebinds `dt` from the loader object to the loaded dataset, so
# re-running this cell presumably fails (the loaded dataset is not expected to expose
# get_data) — consider a separate name for the result.
dt = next(dt.get_data())

Dataset tokenization (num_proc=30):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [15]:
# Decode the first example's input_ids (the part after the random split point) to inspect it.
tokenizer_bert.decode(dt[0]["input_ids"])

'[CLS] susy was a little bit apprehensive. so [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'