In [19]:
from transformers import AutoTokenizer
from datasets import Dataset as TransformersDataset, load_dataset

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader,TensorDataset

from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import StandardScaler, scale, MinMaxScaler
# from sklearn.decomposition import TruncatedSVD

import tensorflow_datasets as tfds

import numpy as np
import pandas as pd

import math
from einops import rearrange

import os
from os import path, walk
import itertools
import json
import random
import datetime
import gc
import time
import re

from tqdm import tqdm

TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
SEED = 333


def seedBasic(seed=SEED):
    """Seed Python's `random` module and NumPy's legacy global RNG.

    Also exports PYTHONHASHSEED. Note that this environment variable only affects
    Python processes started afterwards; the current interpreter's str hashing is
    fixed at startup.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)


def seedTorch(seed=SEED):
    """Seed torch's CPU and CUDA generators and make cuDNN deterministic."""
    torch.manual_seed(seed)
    # The model is wrapped in nn.DataParallel, so seed every visible GPU,
    # not only the current one (safe no-op when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def seedEverything(seed=SEED):
    """Seed all RNGs used in this notebook (basic + torch)."""
    seedBasic(seed)
    seedTorch(seed)


seedEverything()

In [None]:
# Pretrained T5 tokenizer; its sentinel tokens (<extra_id_N>) drive the UL2 corruption below.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='t5-small')

In [None]:
class UL2Dataset():
    def __init__(self,
                 text_list,
                 mu=3,
                 d=0.2,
                 probabilities={"s": 0.2, "b": 0.2, "c": 0.5},
                 tokenizer=None):
        """
        Map-style dataset that turns raw texts into UL2-style (input, target) text pairs.

        text_list: list of raw strings; each one is tokenized lazily in __getitem__
        mu: base span length. For the s/c modes the span length is drawn from
            [mu, len//2] when the text is long enough, else from [1, mu]; for the
            b mode it is a lower bound on the largest allowed span.
        d: fraction of tokens corrupted by the bidirectional (b) method
        probabilities: sampling weights for the corruption method
            s: corrupt just the first tokens
            b: corrupt uniformly (bidirectional)
            c: corrupt the last tokens (causal)
        tokenizer: HF-style tokenizer; it is called on the text, must provide
            `decode`, and `additional_special_tokens_ids` (used as sentinel ids,
            in order <extra_id_0>, <extra_id_1>, ...)
        """
        if tokenizer is None:
            raise ValueError("UL2Dataset requires a tokenizer")
        self.text_list = text_list
        self.mu = mu
        self.d = d
        # Copy so neither the shared default dict nor the caller's dict is aliased.
        self.probabilities = dict(probabilities)

        self.tokenizer = tokenizer
        self.extra_ids = tokenizer.additional_special_tokens_ids

    def _span_length(self, l):
        """Random span length for the s/c modes, clamped to the sequence length `l`."""
        if l // 2 > self.mu:
            span_length = random.randint(self.mu, l // 2)
        else:
            span_length = random.randint(1, self.mu)
        return min(span_length, l)

    def _s_noise(self, tokens):
        """Hide a leading span: input `<s0> rest`, target `<s0> prefix <s1>`."""
        sentinel, closing = self.extra_ids[0], self.extra_ids[1]
        span_length = self._span_length(len(tokens))

        # The target must open with the sentinel that marks the gap in the input.
        target_tokens = [sentinel] + tokens[:span_length] + [closing]
        tokens = [sentinel] + tokens[span_length:]

        return tokens, target_tokens

    def _c_noise(self, tokens):
        """Hide a trailing span: input `head <s0>`, target `<s0> suffix <s1>`."""
        sentinel, closing = self.extra_ids[0], self.extra_ids[1]
        cut = len(tokens) - self._span_length(len(tokens))

        # Slice with an explicit cut index (tokens[-0:] would return everything).
        target_tokens = [sentinel] + tokens[cut:] + [closing]
        tokens = tokens[:cut] + [sentinel]

        return tokens, target_tokens

    def _b_noise(self, tokens):
        """
        Corrupt ~d * len(tokens) tokens spread over several random spans.

        Each span is replaced by sentinel k in the input; the target lists
        `<sk> span_k` for every span and ends with the next sentinel.
        NOTE: raises IndexError if more spans are produced than sentinels exist.
        """
        l = len(tokens)
        if l == 0:
            return [self.extra_ids[0]], [self.extra_ids[0], self.extra_ids[1]]
        n = min(max(int(l * self.d), 1), l)  # number of corrupted tokens
        span_len_list = []
        space_len_list = []
        max_span = max(n//4, self.mu)

        while sum(span_len_list) < n:
            span_len_list.append(random.randint(1, min(max_span, n - sum(span_len_list))))

        # Gap before each span; the upper bound reserves one token for every later gap.
        # It can only go negative for very dense corruption, where 0 is the only option.
        for j in range(len(span_len_list)):
            upper = (l - n) - sum(space_len_list) - len(span_len_list) + j + 1
            space_len_list.append(random.randint(0, max(0, upper)))

        # Uncorrupted tail after the last span, if any.
        reste = (l - n)-sum(space_len_list)
        if reste > 0:
            space_len_list.append(reste)

        target_tokens = []
        x_tokens = []
        # Merge spans separated by an empty gap into a single span.
        if len(space_len_list) > 1:
            kk = 1
            while kk < len(space_len_list):
                if space_len_list[kk] == 0:
                    space_len_list.pop(kk)
                    span_len_list[kk-1] += span_len_list.pop(kk)
                    kk -= 1
                kk += 1

        for k in range(len(space_len_list)):
            a = sum(space_len_list[:k])
            a += sum(span_len_list[:k])
            b = a + space_len_list[k]
            x_tokens.append(tokens[a:b])
            if k < len(span_len_list):
                x_tokens[len(x_tokens)-1] = x_tokens[len(x_tokens)-1] + [self.extra_ids[k]]
                target_tokens.append([self.extra_ids[k]] + tokens[b: b + span_len_list[k]])
            else:
                h = len(target_tokens) - 1
                target_tokens[h] = target_tokens[h] + [self.extra_ids[k]]

        # Text ended inside a span: still close the target with the next sentinel.
        if len(space_len_list) == len(span_len_list):
            target_tokens[-1] = target_tokens[-1] + [self.extra_ids[len(span_len_list)]]

        return [item for sublist in x_tokens for item in sublist], [item for sublist in target_tokens for item in sublist]

    def __getitem__(self, index):
        """Return one randomly corrupted (input_text, target_text) pair for text `index`."""
        text = self.text_list[index]
        tokenizer_result = self.tokenizer(text, padding=False, truncation=True, add_special_tokens=False)
        tokens = tokenizer_result['input_ids']

        noise_fns = [self._s_noise, self._b_noise, self._c_noise]
        weights = [self.probabilities["s"], self.probabilities["b"], self.probabilities["c"]]

        # Choose a corruption method at random, weighted by `probabilities`.
        func, = random.choices(noise_fns, weights=weights)
        x, y = func(tokens)

        return self.tokenizer.decode(x), self.tokenizer.decode(y)

    def __len__(self):
        return len(self.text_list)

In [None]:
d = UL2Dataset(
    ["1n 2b 3g 4k 5i 6g 7i 8p 9o 11 22 55 44 77 88 99 33"],
    probabilities={"s": 0.2, "b": 0.3, "c": 0.5},
    tokenizer=tokenizer)
# Print one corrupted (input, target) pair per text; the noise mode is drawn at random.
for x, y in (d[i] for i in range(len(d))):
    print(x, y, sep="\n")

In [None]:
def exists(val):
    """Return True for any value other than None (falsy values such as 0 count)."""
    return not (val is None)

def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if exists(val):
        return val
    return d

In [None]:
class T5LayerNorm(nn.Module):
    """
    T5-style layer norm (RMSNorm): scales only, with no bias and no mean subtraction.

    config keys read: d_model (size of the learned scale), norm_eps.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config["d_model"]
        self.weight = nn.Parameter(torch.ones(d_model))
        self.variance_epsilon = config["norm_eps"]

    def forward(self, x):
        # Mean of squares over the last axis, accumulated in float32 so that
        # half-precision inputs do not lose precision (https://arxiv.org/abs/1910.07467).
        mean_square = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.variance_epsilon)

        # Cast back when the module itself is held in half precision.
        half_dtypes = (torch.float16, torch.bfloat16)
        if self.weight.dtype in half_dtypes:
            normed = normed.to(self.weight.dtype)

        return self.weight * normed

In [None]:
class FeedForward(nn.Module):
    """
    Gated feed-forward block (SiLU gate, bias-free):
        wo(dropout(silu(wi_0(x)) * wi_1(x)))

    config keys read: d_model, d_ff, dropout_rate.
    """

    def __init__(self, config):
        super().__init__()
        d_model, d_ff = config["d_model"], config["d_ff"]
        self.wi_0 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.wi_1 = nn.Linear(d_model, d_ff, bias=False)  # value projection
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(config["dropout_rate"])
        self.act = nn.SiLU()

    def forward(self, x):
        gated = self.act(self.wi_0(x)) * self.wi_1(x)
        return self.wo(self.dropout(gated))

In [None]:
class T5RelativePositionBias(nn.Module):
    """
    Learned T5 relative-position bias added to attention scores.

    Relative distances are mapped to `num_buckets` buckets (exact for short
    distances, log-spaced up to `max_distance`), and each bucket has one learned
    scalar per head. The bias is multiplied by dim_head ** -0.5 before being added.

    config keys read: dim_head, causal, num_buckets, max_distance, heads.
    """

    def __init__(self, config):
        super().__init__()
        self.scale = config["dim_head"] ** -0.5
        self.causal = config["causal"]
        self.num_buckets = config["num_buckets"]
        self.max_distance = config["max_distance"]
        self.relative_attention_bias = nn.Embedding(config["num_buckets"], config["heads"])

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        """Map integer relative positions (key - query) to bucket indices."""
        distance = -relative_position
        offset = 0
        if causal:
            # Only look backwards: future positions collapse onto distance 0.
            distance = distance.clamp(min=0)
        else:
            # Half of the buckets for each direction; the upper half marks n < 0.
            num_buckets = num_buckets // 2
            offset = (distance < 0).long() * num_buckets
            distance = distance.abs()

        exact_limit = num_buckets // 2
        log_ratio = torch.log(distance.float() / exact_limit) / math.log(max_distance / exact_limit)
        large_bucket = exact_limit + (log_ratio * (num_buckets - exact_limit)).long()
        large_bucket = large_bucket.clamp(max=num_buckets - 1)

        return offset + torch.where(distance < exact_limit, distance, large_bucket)

    def forward(self, qk_dots):
        """qk_dots: scores whose last two axes are (queries, keys); returns qk_dots + scaled bias."""
        q_len, k_len = qk_dots.shape[-2:]
        dev = qk_dots.device
        # Queries are aligned with the last q_len key positions.
        query_pos = torch.arange(k_len - q_len, k_len, dtype = torch.long, device = dev)
        key_pos = torch.arange(k_len, dtype = torch.long, device = dev)
        rel_pos = key_pos.unsqueeze(0) - query_pos.unsqueeze(1)
        buckets = self._relative_position_bucket(
            rel_pos,
            causal = self.causal,
            num_buckets = self.num_buckets,
            max_distance = self.max_distance
        )
        # (i, j, heads) -> (heads, i, j)
        bias = self.relative_attention_bias(buckets).permute(2, 0, 1)
        return qk_dots + bias * self.scale

In [None]:
class T5SelfAttention(nn.Module):
    """
    Multi-head self-attention with a learned T5 relative-position bias.

    config keys read: d_model, dim_head, heads, causal, dropout_rate
    (plus the keys read by T5RelativePositionBias).
    """
    def __init__(self, config):
        super().__init__()
        inner_dim = config["dim_head"] * config["heads"]
        self.heads = config["heads"]
        self.scale = config["dim_head"] ** -0.5
        self.causal = config["causal"]

        self.to_q = nn.Linear(config["d_model"], inner_dim, bias = False)
        self.to_k = nn.Linear(config["d_model"], inner_dim, bias = False)
        self.to_v = nn.Linear(config["d_model"], inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, config["d_model"])

        self.relative_position_bias = T5RelativePositionBias(config)

        self.dropout = nn.Dropout(config["dropout_rate"])

    @staticmethod
    def _expand_mask(mask):
        """
        Normalize a keep-mask (True / nonzero = may attend) to a bool tensor that
        broadcasts against attention scores of shape (b, h, i, j):
            (b, j)    -> (b, 1, 1, j)   key-padding mask (e.g. a tokenizer attention_mask)
            (b, i, j) -> (b, 1, i, j)   per-query mask shared by all heads
        Other ranks are only cast to bool and must already broadcast.
        """
        mask = mask.bool()
        if mask.dim() == 2:
            return mask[:, None, None, :]
        if mask.dim() == 3:
            return mask[:, None, :, :]
        return mask

    def forward(self, x, mask = None):
        """
        x: (b, n, d_model) hidden states.
        mask: optional keep-mask; positions where it is False/0 get no attention.
        Returns (b, n, d_model).
        """
        b, n, _, h = *x.shape, self.heads
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * self.scale

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)

        sim = self.relative_position_bias(sim)

        # mask

        mask_value = -torch.finfo(sim.dtype).max

        if mask is not None:
            # `~` on an integer mask is a bitwise NOT (and masked_fill rejects it),
            # so the mask is converted to bool and reshaped first.
            sim = sim.masked_fill(~self._expand_mask(mask), mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')

        # combine heads and linear output

        return self.to_out(out)

In [None]:
class T5CrossAttention(nn.Module):
    """
    Multi-head attention from decoder states (queries) to a context sequence (keys/values).

    config keys read: d_model, dim_head, heads, dropout_rate, and optionally
    context_dim (falls back to d_model when absent or None).
    """
    def __init__(self,config):
        super().__init__()
        inner_dim = config["dim_head"] * config["heads"]
        # Use the computed fallback: reading config["context_dim"] directly raised
        # KeyError when the key was missing and failed when it was None.
        context_dim = default(config.get("context_dim"), config["d_model"])

        self.heads = config["heads"]
        self.scale = config["dim_head"] ** -0.5

        self.to_q = nn.Linear(config["d_model"], inner_dim, bias = False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias = False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, config["d_model"])

        self.dropout = nn.Dropout(config["dropout_rate"])

    @staticmethod
    def _expand_mask(mask):
        """
        Normalize a keep-mask (True / nonzero = may attend) to a bool tensor that
        broadcasts against attention scores of shape (b, h, i, j):
            (b, j)    -> (b, 1, 1, j)   mask over context positions
            (b, i, j) -> (b, 1, i, j)   per-query mask shared by all heads
        Other ranks are only cast to bool and must already broadcast.
        """
        mask = mask.bool()
        if mask.dim() == 2:
            return mask[:, None, None, :]
        if mask.dim() == 3:
            return mask[:, None, :, :]
        return mask

    def forward(self, x, context, mask = None, context_mask = None):
        """
        x: (b, i, d_model) query states; context: (b, j, context_dim), or None to attend to x.
        mask / context_mask: optional keep-masks (see _expand_mask).
        Returns (b, i, d_model).
        """
        b, n, _, h = *x.shape, self.heads

        kv_input = default(context, x)

        q, k, v = self.to_q(x), self.to_k(kv_input), self.to_v(kv_input)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * self.scale

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)

        # mask

        mask_value = -torch.finfo(sim.dtype).max

        if mask is not None:
            sim = sim.masked_fill(~self._expand_mask(mask), mask_value)

        if context_mask is not None:
            # A (b, j) mask must become (b, 1, 1, j); the previous [:, None, :]
            # gave (b, 1, j), which broadcast the batch axis against the heads axis.
            sim = sim.masked_fill(~self._expand_mask(context_mask), mask_value)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')

        # combine heads and linear output

        return self.to_out(out)


In [None]:
class SublayerConnection(nn.Module):
    """
    Pre-norm residual wrapper:

        out = x + dropout(sublayer(norm(x), **kwargs))

    The norm is applied to the sublayer's input, not to the residual sum.
    """

    def __init__(self, config):
        super().__init__()
        self.norm = T5LayerNorm(config)
        self.dropout = nn.Dropout(config["dropout_rate"])

    def forward(self, sublayer, x, **kwargs):
        """Run `sublayer` on the normalized input and add the result back onto `x`."""
        normed = self.norm(x)
        update = self.dropout(sublayer(normed, **kwargs))
        return x + update

In [None]:
class T5EncoderLayer(nn.Module):
    """
    One pre-norm encoder block: self-attention, then the gated feed-forward,
    each wrapped in its own residual SublayerConnection.
    """
    def __init__(self, config):
        super().__init__()
        self.att = T5SelfAttention(config)
        self.mlp = FeedForward(config)

        self.sublayer1 = SublayerConnection(config)
        self.sublayer2 = SublayerConnection(config)

    def forward(self, x, mask = None):
        x = self.sublayer1(self.att, x, mask = mask)

        # The feed-forward gets its own pre-norm (sublayer2). Reusing sublayer1 here
        # would share one norm between attention and FFN and leave sublayer2 unused.
        x = self.sublayer2(self.mlp, x)

        return x

In [None]:
class T5Encoder(nn.Module):
    """
    Stack of config["depth"] encoder layers followed by a final RMS norm.

    Takes already-embedded inputs; the token embedding lives in the parent T5 model.
    """
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList(T5EncoderLayer(config) for _ in range(config["depth"]))
        self.final_norm = T5LayerNorm(config)

    def forward(self, x, mask = None):
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden, mask)
        return self.final_norm(hidden)

In [None]:
class T5DecoderLayer(nn.Module):
    """
    One pre-norm decoder block: self-attention, cross-attention to the encoder
    output (`context`), then the gated feed-forward, each in its own residual wrapper.
    """
    def __init__(self, config):
        super().__init__()
        self.att = T5SelfAttention(config)
        self.cross_attn = T5CrossAttention(config)
        self.mlp = FeedForward(config)

        self.sublayer1 = SublayerConnection(config)
        self.sublayer2 = SublayerConnection(config)
        self.sublayer3 = SublayerConnection(config)

    def forward(self, x, context, mask = None, context_mask = None):
        # NOTE(review): the same `mask` reaches both self-attention (keys = decoder
        # positions) and cross-attention (keys = context positions); it can only be
        # valid for both when it broadcasts over either key length. Confirm intent.
        steps = (
            (self.sublayer1, self.att, {"mask": mask}),
            (self.sublayer2, self.cross_attn, {"context": context, "mask": mask, "context_mask": context_mask}),
            (self.sublayer3, self.mlp, {}),
        )
        for wrapper, sublayer, kwargs in steps:
            x = wrapper(sublayer, x, **kwargs)
        return x

In [None]:
class T5Decoder(nn.Module):
    """
    Embeds decoder token ids, runs config["depth"] decoder layers against the
    encoder output `context`, then applies a final RMS norm.
    """
    def __init__(self, config):
        super().__init__()
        self.token_emb = nn.Embedding(config["num_tokens"], config["d_model"])
        self.layers = nn.ModuleList(T5DecoderLayer(config) for _ in range(config["depth"]))
        self.final_norm = T5LayerNorm(config)

    def forward(self, x, context, mask = None, context_mask = None):
        hidden = self.token_emb(x)
        for layer in self.layers:
            hidden = layer(hidden, context, mask, context_mask)
        return self.final_norm(hidden)

In [None]:
class T5(nn.Module):
    """
    Encoder-decoder model: embeds `src`, encodes it, decodes `tgt` against the
    encoding and projects the result to vocabulary logits.

    tie_token_emb: when True, the encoder-side input embedding (self.token_emb)
    shares its weight with the decoder's input embedding. The output projection
    `to_logits` is a separate matrix either way.
    """
    def __init__(self, encoder_config, decoder_config, tie_token_emb=True):
        super().__init__()
        self.token_emb = nn.Embedding(encoder_config["num_tokens"], encoder_config["d_model"])
        self.encoder = T5Encoder(encoder_config)
        self.decoder = T5Decoder(decoder_config)
        self.to_logits = nn.Linear(decoder_config["d_model"], decoder_config["num_tokens"], bias=False)

        if tie_token_emb:
            self.token_emb.weight = self.decoder.token_emb.weight

    def forward(self, src, tgt, mask = None, context_mask = None):
        """
        src: source token ids; tgt: decoder input token ids.
        Returns logits with decoder_config["num_tokens"] entries on the last axis.

        NOTE(review): `mask` is passed to the encoder AND to the decoder, where it
        also reaches decoder self-attention; confirm this is the intended routing.
        """
        encoded = self.encoder(self.token_emb(src), mask = mask)
        decoded = self.decoder(tgt, encoded, mask = mask, context_mask = context_mask)
        return self.to_logits(decoded)

In [None]:
def get_lr_scheduler(optimizer, batch_size = 64, last_epoch=-1,
                     lr_start=0.00001, lr_max_per_sample=0.00003, lr_min=0.000001,
                     lr_ramp_ep=30, lr_sus_ep=0, lr_decay=0.9):
    """
    Linear warm-up, optional plateau, then exponential decay, as a LambdaLR.

    Schedule factor f(epoch):
        epoch < lr_ramp_ep                 : linear from lr_start to lr_max
        epoch < lr_ramp_ep + lr_sus_ep     : lr_max
        otherwise                          : (lr_max - lr_min) * lr_decay**k + lr_min
    where lr_max = lr_max_per_sample * batch_size.

    NOTE: LambdaLR multiplies each param group's initial lr by f(epoch), so the
    effective learning rate is optimizer_lr * f(epoch), not f(epoch) itself.
    The defaults reproduce the previously hard-coded constants.
    """
    lr_max = lr_max_per_sample * batch_size

    def lrfn(epoch):
        if epoch < lr_ramp_ep:
            return (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
        if epoch < lr_ramp_ep + lr_sus_ep:
            return lr_max
        return (lr_max - lr_min) * lr_decay ** (epoch - lr_ramp_ep - lr_sus_ep) + lr_min

    # `verbose` is deprecated in recent PyTorch (and False was the default anyway).
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lrfn, last_epoch=last_epoch)

In [None]:
def get_data_loader(ds, ul2_arg, batch_size=64):
    """
    Corrupt every text in `ds` with UL2 noise and pack the results into a DataLoader.

    ul2_arg: dict with keys "mu", "d" and "probabilities" (forwarded to UL2Dataset).
    Each batch is (input_ids, input_attention_mask, target_ids), all torch.long.
    Uses the module-level `tokenizer`.
    """
    ul2_data = UL2Dataset(ds, tokenizer=tokenizer, mu=ul2_arg["mu"], d=ul2_arg["d"], probabilities=ul2_arg["probabilities"])
    pairs = [ul2_data[j] for j in range(len(ul2_data))]
    inputs = [source for source, _ in pairs]
    targets = [target for _, target in pairs]

    # Inputs and targets are tokenized in a single call so that padding brings
    # both halves to one common length; the first half are inputs, the rest targets.
    encoded = tokenizer([*inputs, *targets], padding=True, truncation=True)
    half = len(encoded['input_ids']) // 2

    input_ids = torch.tensor(encoded['input_ids'][:half], dtype=torch.long)
    input_mask = torch.tensor(encoded['attention_mask'][:half], dtype=torch.long)
    target_ids = torch.tensor(encoded['input_ids'][half:], dtype=torch.long)

    return DataLoader(TensorDataset(input_ids, input_mask, target_ids), batch_size=batch_size)

In [None]:
def validate(model, val_loader):
    """
    Return the mean cross-entropy of `model` over `val_loader`, without gradients.

    Batches are (input_ids, attention_mask, labels). The decoder is fed the
    labels shifted right by one position, starting with the pad token (T5's
    decoder start token), and pad positions in the labels are excluded from the
    loss. This teacher forcing only makes sense with a causal decoder.
    Uses the module-level `device` and `tokenizer`.
    """
    if not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model)

    model = model.to(device)
    model.eval()

    pad_id = tokenizer.pad_token_id
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
    loss_list = []

    with torch.no_grad():
        # The input attention mask is not forwarded to the model (previously it was
        # wrongly passed as the decoder's token ids).
        for x, _att, labels in val_loader:
            x = x.to(device, dtype=torch.long)
            labels = labels.to(device, dtype=torch.long)

            decoder_input = torch.cat([torch.full_like(labels[:, :1], pad_id), labels[:, :-1]], dim=1)
            y = model(x, decoder_input)
            loss = loss_fn(y.reshape((-1, y.shape[-1])), labels.reshape(-1))

            loss_list.append(loss.item())

    return float(np.mean(loss_list))

In [None]:
def train(model, args, checkpoint=None):
    """
    Train `model` on UL2-corrupted text streamed from books.csv in chunks.

    Each CSV chunk (10**5 rows) is split 90/10 into train/validation, turned into
    DataLoaders with `get_data_loader`, trained for one pass with gradient
    accumulation, validated and checkpointed to 'model.pth'. The LR scheduler steps
    once per chunk, and training stops after roughly 9 hours of wall time.

    args: dict with keys "optimizer", "lr", "scheduler", "batch_size",
        "NUM_ACCUMULATION_STEPS" and, depending on the scheduler, "cosin_T_max" or
        "lrscheduler_start" / "lrscheduler_step" / "lrscheduler_decay".
    checkpoint: optional dict with "model_state_dict", "optimizer_state_dict" and
        "chunk"; chunks up to and including "chunk" are skipped on resume.

    Teacher forcing: the decoder is fed the labels shifted right by one position,
    starting with the pad token (T5's decoder start token). The decoder
    self-attention must therefore be causal, or it can see the token it predicts.
    """
    torch.set_grad_enabled(True)

    start_time = time.time()

    if not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model)

    model = model.to(device)

    # Set up the optimizer
    trainables = [p for p in model.parameters() if p.requires_grad]
    print('Total parameter number is : {:.3f} million'.format(sum(p.numel() for p in model.parameters()) / 1e6))
    print('Total trainable parameter number is : {:.3f} million'.format(sum(p.numel() for p in trainables) / 1e6))

    if args["optimizer"] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"], weight_decay=5e-7, betas=(0.95, 0.999))
    elif args["optimizer"] == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"], weight_decay=5e-7, amsgrad=True)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args["lr"], momentum=0.9, nesterov=True, weight_decay=5e-7)

    # Resume: restore weights and optimizer state, remember the last finished chunk.
    last_epoch = -1
    if checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        last_epoch = checkpoint["chunk"]

    if args["scheduler"] == "LambdaLR":
        scheduler = get_lr_scheduler(optimizer, batch_size = args["batch_size"] * args["NUM_ACCUMULATION_STEPS"], last_epoch=last_epoch)
    elif args["scheduler"] == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args["cosin_T_max"], last_epoch=last_epoch)
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, list(range(args["lrscheduler_start"], 1000, args["lrscheduler_step"])), gamma=args["lrscheduler_decay"], last_epoch=last_epoch)

    accum_steps = args["NUM_ACCUMULATION_STEPS"]
    batch_size = args['batch_size']
    pad_id = tokenizer.pad_token_id
    # Label padding carries no signal; exclude it from the loss.
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
    ul2_arg = {
        "mu": 3,
        "d": 0.2,
        "probabilities": {"s": 0.2, "b": 0.3, "c": 0.5}
    }

    chunksize = 10 ** 5
    with pd.read_csv("/kaggle/input/nlp-dataset/books.csv", chunksize=chunksize) as reader:
        for i, chunk in enumerate(reader):
            if last_epoch >= i:
                continue
            # Missing texts (NaN) would make the tokenizer fail.
            chunk = chunk[["text"]].dropna()
            train_chunk, test_chunk = train_test_split(chunk, test_size=0.1, random_state=42)
            train_loader = get_data_loader(
                train_chunk["text"].astype(str).tolist(),
                ul2_arg=ul2_arg,
                batch_size=batch_size)
            test_loader = get_data_loader(
                test_chunk["text"].astype(str).tolist(),
                ul2_arg=ul2_arg,
                batch_size=batch_size)

            model.train()

            loss_train = []

            # The input attention mask (_att) is not passed to the model, so
            # attention also covers pad positions.
            for k, (x, _att, labels) in enumerate(train_loader):
                x = x.to(device, dtype=torch.long)
                labels = labels.to(device, dtype=torch.long)

                decoder_input = torch.cat([torch.full_like(labels[:, :1], pad_id), labels[:, :-1]], dim=1)
                y = model(x, decoder_input)

                loss = loss_fn(y.reshape((-1, y.shape[-1])), labels.reshape(-1))

                # Scale only the backward pass for gradient accumulation; log the
                # unscaled loss so it is comparable with the validation loss.
                (loss / accum_steps).backward()

                if ((k + 1) % accum_steps == 0) or (k + 1 == len(train_loader)):
                    optimizer.step()
                    optimizer.zero_grad()

                loss_train.append(loss.item())

            train_loss = np.mean(loss_train)
            val_loss = validate(model, test_loader)
            lr = scheduler.get_last_lr()[0]

            del train_loader, test_loader, train_chunk, test_chunk
            gc.collect()

            print(f"chunk: {i}, lr: {lr:.8f}, train loss: {train_loss:.6f}, val loss: {val_loss:.6f}")

            scheduler.step()

            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'chunk': i
            }, 'model.pth')

            if time.time() - start_time > 60*60*9:
                break

In [None]:
args = {
    "lr": 0.001,                      # base lr; LambdaLR multiplies it by the schedule factor
    "lrscheduler_start": 15,          # MultiStepLR only
    "lrscheduler_step": 10,           # MultiStepLR only
    "lrscheduler_decay": 0.5,         # MultiStepLR only
    "warmup": True,                   # not read by train()
    "optimizer": ["adam", "adamw", "sgd"][1],
    "scheduler": ["LambdaLR", "cosine"][0],
    "batch_size": 8,
    "NUM_ACCUMULATION_STEPS": 1024//8  # effective batch of 1024 sequences per optimizer step
}

d_model =  512
depth = 8
encoder_config = {
    "d_model": d_model,
    "d_ff": int(d_model*2.5),
    "dropout_rate": 0,
    "causal": False,                  # encoder attends in both directions
    "num_buckets": 32,
    "max_distance": 128,
    "heads": 12,
    "dim_head": 64,
    "depth": depth,
    "num_tokens": tokenizer.vocab_size,
    "norm_eps": 1e-6
}
decoder_config = {
    "d_model": d_model,
    "d_ff": int(d_model*2.5),
    "dropout_rate": 0,
    # The decoder is trained with teacher forcing, so its self-attention must be
    # causal; otherwise each position can attend to the token it has to predict.
    # (Does not change parameter shapes, so existing checkpoints still load.)
    "causal": True,
    "num_buckets": 32,
    "max_distance": 128,
    "heads": 12,
    "dim_head": 64,
    "context_dim": encoder_config["d_model"],
    "depth": depth,
    "num_tokens": tokenizer.vocab_size,
    "norm_eps": 1e-6
}

model = T5(encoder_config, decoder_config)

In [None]:
CHECKPOINT_PATH = "/kaggle/input/t5model/model.pth"
# map_location lets a GPU-saved checkpoint be restored on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints you produced yourself.
if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
else:
    print(f"No checkpoint at {CHECKPOINT_PATH}; training from scratch.")
    checkpoint = None
train(model, args, checkpoint = checkpoint)