In [6]:
from torch.utils.data.dataset import T_co
notebook_idenifier = "gated_transformer"
model_identifier = "all_features"

# GatedTabTransformer pytorch implementation

In [7]:
import glob
import os
import wandb
import torch
from torch import nn
import numpy as np
from tqdm import tqdm
import pandas as pd
import json
import gc
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

from torch.utils.data import DataLoader,Dataset
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

import warnings
warnings.filterwarnings('ignore')

In [8]:
from torch import nn

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
def dropout_layers(layers, prob_survival):
    if prob_survival == 1:
        return layers

    num_layers = len(layers)
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) > prob_survival

    # make sure at least one layer makes it
    if all(to_drop):
        rand_index = randrange(num_layers)
        to_drop[rand_index] = False

    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
    return layers


def shift(t, amount, mask = None):
    if amount == 0:
        return t
    return F.pad(t, (0, 0, amount, -amount), value = 0.)


class PreShiftTokens(nn.Module):
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        if self.shifts == (0,):
            return self.fn(x, **kwargs)

        shifts = self.shifts
        segments = len(shifts)
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim = -1)
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        segments_to_shift = list(map(lambda args: shift(*args), zip(segments_to_shift, shifts)))
        x = torch.cat((*segments_to_shift, *rest), dim = -1)
        return self.fn(x, **kwargs)


class Attention(nn.Module):
    def __init__(self, dim_in, dim_out, dim_inner, causal = False):
        super().__init__()
        self.scale = dim_inner ** -0.5
        self.causal = causal

        self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim_out)

    def forward(self, x):
        device = x.device
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if self.causal:
            mask = torch.ones(sim.shape[-2:], device = device).triu(1).bool()
            sim.masked_fill_(mask[None, ...], -torch.finfo(q.dtype).max)

        attn = sim.softmax(dim = -1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        return self.to_out(out)


class SpatialGatingUnit(nn.Module):
    def __init__(
        self,
        dim,
        dim_seq,
        causal = False,
        act = nn.Identity(),
        heads = 1,
        init_eps = 1e-3,
        circulant_matrix = False
    ):
        super().__init__()
        dim_out = dim // 2
        self.heads = heads
        self.causal = causal
        self.norm = nn.LayerNorm(dim_out)

        self.act = act

        # parameters

        if circulant_matrix:
            self.circulant_pos_x = nn.Parameter(torch.ones(heads, dim_seq))
            self.circulant_pos_y = nn.Parameter(torch.ones(heads, dim_seq))

        self.circulant_matrix = circulant_matrix
        shape = (heads, dim_seq,) if circulant_matrix else (heads, dim_seq, dim_seq)
        weight = torch.zeros(shape)

        self.weight = nn.Parameter(weight)
        init_eps /= dim_seq
        nn.init.uniform_(self.weight, -init_eps, init_eps)

        self.bias = nn.Parameter(torch.ones(heads, dim_seq))

    def forward(self, x, gate_res = None):
        device, n, h = x.device, x.shape[1], self.heads

        res, gate = x.chunk(2, dim = -1)
        gate = self.norm(gate)

        weight, bias = self.weight, self.bias

        if self.circulant_matrix:
            # build the circulant matrix

            dim_seq = weight.shape[-1]
            weight = F.pad(weight, (0, dim_seq), value = 0)
            weight = repeat(weight, '... n -> ... (r n)', r = dim_seq)
            weight = weight[:, :-dim_seq].reshape(h, dim_seq, 2 * dim_seq - 1)
            weight = weight[:, :, (dim_seq - 1):]

            # give circulant matrix absolute position awareness

            pos_x, pos_y = self.circulant_pos_x, self.circulant_pos_y
            weight = weight * rearrange(pos_x, 'h i -> h i ()') * rearrange(pos_y, 'h j -> h () j')

        if self.causal:
            weight, bias = weight[:, :n, :n], bias[:, :n]
            mask = torch.ones(weight.shape[-2:], device = device).triu_(1).bool()
            mask = rearrange(mask, 'i j -> () i j')
            weight = weight.masked_fill(mask, 0.)

        gate = rearrange(gate, 'b n (h d) -> b h n d', h = h)

        gate = einsum('b h n d, h m n -> b h m d', gate, weight)
        gate = gate + rearrange(bias, 'h n -> () h n ()')

        gate = rearrange(gate, 'b h n d -> b n (h d)')

        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res


class gMLPBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_ff,
        seq_len,
        heads = 1,
        attn_dim = None,
        causal = False,
        act = nn.Identity(),
        circulant_matrix = False
    ):
        super().__init__()
        self.proj_in = nn.Sequential(
            nn.Linear(dim, dim_ff),
            nn.GELU()
        )

        self.attn = Attention(dim, dim_ff // 2, attn_dim, causal) if exists(attn_dim) else None

        self.sgu = SpatialGatingUnit(dim_ff, seq_len, causal, act, heads, circulant_matrix = circulant_matrix)
        self.proj_out = nn.Linear(dim_ff // 2, dim)

    def forward(self, x):
        gate_res = self.attn(x) if exists(self.attn) else None
        x = self.proj_in(x)
        x = self.sgu(x, gate_res = gate_res)
        x = self.proj_out(x)
        return x


class gMLP(nn.Module):
    def __init__(
        self,
        *,
        num_tokens = None,
        dim,
        depth,
        seq_len,
        heads = 1,
        ff_mult = 4,
        attn_dim = None,
        prob_survival = 1.,
        causal = False,
        circulant_matrix = False,
        shift_tokens = 0,
        act = nn.Identity()
    ):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'

        dim_ff = dim * ff_mult
        self.seq_len = seq_len
        self.prob_survival = prob_survival

        self.to_embed = nn.Embedding(num_tokens, dim) if exists(num_tokens) else nn.Identity()

        token_shifts = tuple(range(0 if causal else -shift_tokens, shift_tokens + 1))
        self.layers = nn.ModuleList([Residual(PreNorm(dim, PreShiftTokens(token_shifts, gMLPBlock(dim = dim, heads = heads, dim_ff = dim_ff, seq_len = seq_len, attn_dim = attn_dim, causal = causal, act = act, circulant_matrix = circulant_matrix)))) for i in range(depth)])

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            Reduce('b n d -> b d', 'mean'),
            nn.Linear(dim, 1)
        )

    def forward(self, x):
        x = self.to_embed(x)
        layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)
        out = nn.Sequential(*layers)(x)
        return self.to_logits(out)


class gMLPClassification(nn.Module):
    def __init__(
        self,
        *,
        patch_width,
        seq_len,
        num_classes,
        dim,
        depth,
        heads = 1,
        ff_mult = 4,
        attn_dim = None,
        prob_survival = 1.
    ):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
        num_patches = (seq_len // patch_width)

        dim_ff = dim * ff_mult

        self.to_patch_embed = nn.Sequential(
            Rearrange('b (w p2) -> b (w) (p2)', p2 = patch_width),
            nn.Linear(patch_width, dim)
        )

        self.prob_survival = prob_survival

        self.layers = nn.ModuleList([Residual(PreNorm(dim, gMLPBlock(dim = dim, heads = heads, dim_ff = dim_ff, seq_len = num_patches, attn_dim = attn_dim))) for i in range(depth)])

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            Reduce('b n d -> b d', 'mean'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        x = self.to_patch_embed(x)
        layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)
        x = nn.Sequential(*layers)(x)
        return self.to_logits(x)


def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def pair(val):
    return (val, val) if not isinstance(val, tuple) else val

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)


class HeadAttention(nn.Module):
    def __init__(
            self,
            dim,
            heads=8,
            dim_head=16,
            dropout=0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = sim.softmax(dim=-1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h=h)
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, num_tokens, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
        super().__init__()
        self.embeds = nn.Embedding(num_tokens, dim)
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, HeadAttention(dim, heads=heads, dim_head=dim_head, dropout=attn_dropout))),
                Residual(PreNorm(dim, FeedForward(dim, dropout=ff_dropout))),
            ]))

    def forward(self, x):
        x = self.embeds(x)

        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)

        return x


class MLP(nn.Module):
    def __init__(self, dims, act=None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            is_last = ind >= (len(dims_pairs) - 1)
            linear = nn.Linear(dim_in, dim_out)
            layers.append(linear)

            if is_last:
                continue

            act = default(act, nn.ReLU())
            layers.append(act)

        self.mlp = layers

    def forward(self, x):
        for i,lay in enumerate(self.mlp):
            x = lay(x)
        return x


class GatedTabTransformer(nn.Module):
    def __init__(
            self,
            *,
            categories,
            num_continuous,
            transformer_dim,
            transformer_depth,
            transformer_heads,
            transformer_dim_head=16,
            dim_out=1,
            mlp_depth=2,
            select_dim = 128,
            mlp_act=None,
            num_special_tokens=2,
            continuous_mean_std=None,
            attn_dropout=0.,
            ff_dropout=0.,
            gmlp_enabled=False,
            mlp_dimension=32,
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'

        # categories related calculations

        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        # create category embeddings table

        self.num_special_tokens = num_special_tokens
        total_tokens = self.num_unique_categories + num_special_tokens

        # for automatically offsetting unique category ids to the correct position in the categories embedding table

        categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value=num_special_tokens)
        categories_offset = categories_offset.cumsum(dim=-1)[:-1]
        self.register_buffer('categories_offset', categories_offset)

        # continuous

        if exists(continuous_mean_std):
            assert continuous_mean_std.shape == (num_continuous,
                                                 2), f'continuous_mean_std must have a shape of ({num_continuous}, 2) where the last dimension contains the mean and variance respectively'
        self.register_buffer('continuous_mean_std', continuous_mean_std)

        self.norm = nn.LayerNorm(num_continuous)
        self.num_continuous = num_continuous

        # transformer

        self.transformer = Transformer(
            num_tokens=total_tokens,
            dim=transformer_dim,
            depth=transformer_depth,
            heads=transformer_heads,
            dim_head=transformer_dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout
        )

        # mlp to logits

        input_size = (transformer_dim * self.num_categories) + num_continuous

        self.encoder = nn.Linear(input_size,select_dim)
        self.decoder = nn.Linear(select_dim,input_size)

        if gmlp_enabled:
            self.mlp = gMLPClassification(
                patch_width=1,
                seq_len=input_size,
                num_classes=dim_out,
                dim=mlp_dimension,
                depth=mlp_depth
            )
        else:
            hidden_dimensions = []

            for i in range(mlp_depth):
                if mlp_dimension == -1:
                    hidden_dimensions.append((input_size // 8) * (2 ** (mlp_depth - i)))
                else:
                    hidden_dimensions.append(mlp_dimension)

            all_dimensions = [input_size, *hidden_dimensions, dim_out]
            self.mlp = MLP(all_dimensions, act=mlp_act)

    def forward(self, x_categ, x_cont=None):
        assert x_categ.shape[
                   -1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
        x_categ += self.categories_offset

        x = self.transformer(x_categ)

        flat_categ = x.flatten(1)

        if self.num_continuous != 0:
            assert x_cont.shape[
                       1] == self.num_continuous, f'you must pass in {self.num_continuous} values for your continuous input'

            if exists(self.continuous_mean_std):
                mean, std = self.continuous_mean_std.unbind(dim=-1)
                x_cont = (x_cont - mean) / std

            normed_cont = self.norm(x_cont)
            x = torch.cat((flat_categ, normed_cont), dim=-1)
        else:
            x = flat_categ

        embed = self.encoder(x)
        x = self.decoder(embed)
        return self.mlp(x), embed

In [9]:
import torch

def amex_metric_pytorch(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:

    # convert dtypes to float64
    y_true = y_true.double()
    y_pred = y_pred.double()

    # count of positives and negatives
    n_pos = y_true.sum()
    n_neg = y_pred.shape[0] - n_pos

    # sorting by descring prediction values
    indices = torch.argsort(y_pred, dim=0, descending=True)
    preds, target = y_pred[indices], y_true[indices]

    # filter the top 4% by cumulative row weights
    weight = 20.0 - target * 19.0
    cum_norm_weight = (weight / weight.sum()).cumsum(dim=0)
    four_pct_filter = cum_norm_weight <= 0.04

    # default rate captured at 4%
    d = target[four_pct_filter].sum() / n_pos

    # weighted gini coefficient
    lorentz = (target / n_pos).cumsum(dim=0)
    gini = ((lorentz - cum_norm_weight) * weight).sum()

    # max weighted gini coefficient
    gini_max = 10 * n_neg * (1 - 19 / (n_pos + 20 * n_neg))

    # normalized weighted gini coefficient
    g = gini / gini_max

    return 0.5 * (g + d)

In [12]:
with open("D:\\Academic_PC\\Sem 7\\Machine Learning\\MiniProject\\data\\xgb_preprocessed\\metadata.json","r") as f0:
    config:dict = json.load(f0)

In [13]:
config["device"] = "gpu"
config["bs"] = 128
config["epoch"] = 20
config["train_shape"] = (367131, 485)
config["val_shape"] = (91782, 485)

In [14]:
wandb.init(config=config)

[34m[1mwandb[0m: Currently logged in as: [33mdevin-18[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
if config['device'] == 'cpu':
    device = "cpu"
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [16]:
#train = pd.read_csv("../data/xgb_preprocessed/val_rows.csv")
#target = pd.read_csv('../data/train_labels.csv').target.values
#test = pd.read_csv("../data/xgb_preprocessed/test_data.csv")

In [17]:
def update_metadata(path,data:dict):
    loaded_config = {}
    if os. path.exists(path):
        with open(path,"r") as f0:
            loaded_config = json.load(f0)

    for k,v in data.items():
        loaded_config[k] = v

    with open(path,"w") as f0:
        json.dump(loaded_config,f0)

In [18]:
gc.collect()

1314

In [19]:
len(config.get("cont_features")),len(config.get("cat_features"))

(359, 124)

In [20]:
class AmexDataset(Dataset):
    def __init__(self,csv_path, cat_headers=config.get("cat_features"),con_headers=config.get("cont_features"),target_col="target",bs=1):
        self.csv_path = csv_path
        self.bs = bs
        self.csv = pd.read_csv(csv_path,usecols=["customer_ID",target_col]+cat_headers+con_headers)
        self.cat_headers = cat_headers
        self.con_headers = con_headers
        self.target_col = target_col

    def __len__(self):
        return self.csv.shape[0]

    def __getitem__(self, index):
        cat_f = self.csv.loc[index,self.cat_headers]
        con_f = self.csv[self.con_headers]
        tar = self.csv[self.target_col]
        return torch.from_numpy(np.int64(cat_f.to_numpy())),torch.from_numpy(np.float64(con_f.to_numpy())),torch.from_numpy(np.int8(tar.to_numpy()))

In [21]:
train_data = AmexDataset("D:\\Academic_PC\\Sem 7\\Machine Learning\\MiniProject\\data\\xgb_preprocessed\\train_rows.csv",
                         cat_headers=config.get("cat_features"),
                         con_headers=config.get("cont_features"),shape=config["train_shape"])
val_data = AmexDataset("D:\\Academic_PC\\Sem 7\\Machine Learning\\MiniProject\\data\\xgb_preprocessed\\val_rows.csv",
                         cat_headers=config.get("cat_features"),
                         con_headers=config.get("cont_features"),shape=config["val_shape"])

In [22]:
model = GatedTabTransformer(
    categories = tuple(config.get("cat_groups")),      # tuple containing the number of unique values within each category
    num_continuous = len(config.get("cont_features")),                # number of continuous values
    transformer_dim = 8,               # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    select_dim=128,
    transformer_depth = 6,              # depth, paper recommended 6
    transformer_heads = 4,              # heads, paper recommends 8
    attn_dropout = 0.2,                 # post-attention dropout
    ff_dropout = 0.2,                   # feed forward dropout
    mlp_act = nn.LeakyReLU(0),          # activation for final mlp, defaults to relu, but could be anything else (selu, etc.)
    mlp_depth=4,                        # mlp hidden layers depth
    mlp_dimension=16,                   # dimension of mlp layers
    gmlp_enabled=True                   # gmlp or standard mlp
)

In [23]:
pred, embedding = model(torch.randint(1,5,(32,len(config.get("cat_groups")))),torch.randn(32,len(config.get("cont_features"))))

In [24]:
pred.shape

torch.Size([32, 1])

In [25]:
embedding.shape

torch.Size([32, 128])

In [26]:
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
model.to(device=device)

GatedTabTransformer(
  (norm): LayerNorm((359,), eps=1e-05, elementwise_affine=True)
  (transformer): Transformer(
    (embeds): Embedding(5458, 8)
    (layers): ModuleList(
      (0): ModuleList(
        (0): Residual(
          (fn): PreNorm(
            (norm): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
            (fn): HeadAttention(
              (to_qkv): Linear(in_features=8, out_features=192, bias=False)
              (to_out): Linear(in_features=64, out_features=8, bias=True)
              (dropout): Dropout(p=0.2, inplace=False)
            )
          )
        )
        (1): Residual(
          (fn): PreNorm(
            (norm): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
            (fn): FeedForward(
              (net): Sequential(
                (0): Linear(in_features=8, out_features=64, bias=True)
                (1): GEGLU()
                (2): Dropout(p=0.2, inplace=False)
                (3): Linear(in_features=32, out_features=8, bias=True)
   

In [27]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i in tqdm(range(len(train_data)),desc="Training",total=(config["train_shape"][0]//config["bs"])+1):
        cat_b,con_b,tar = train_data[i]

        cat_b = cat_b.long().to(device)
        con_b = con_b.float().to(device)
        tar = tar.long().to(device)
        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs, _ = model(cat_b,con_b)

        print(outputs.size(),tar.size())
        # Compute the loss and its gradients
        loss = loss_fn(outputs,tar)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_data) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

In [28]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

best_vloss = 1_000_000.

for epoch in range(config.get("epoch")):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # We don't need gradients on to do reporting
    model.train(False)

    running_vloss = 0.0
    for i in tqdm(range(len(val_data)),desc="Validation",total=(config["val_shape"][0]//config["bs"])+1):
        cat_b,con_b,tar = val_data[i]

        cat_b = cat_b.long().to(device)
        con_b = con_b.float().to(device)
        tar = tar.long().to(device)

        voutputs = model(cat_b,con_b)
        outputs = torch.argmax(voutputs,1).float()

        vloss = loss_fn(voutputs,tar)
        running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:


Training:   0%|          | 0/2869 [00:00<?, ?it/s]

torch.Size([0, 1]) torch.Size([0])


Training:   0%|          | 1/2869 [00:02<1:44:37,  2.19s/it]


RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`