# Neural Network Modeling with MULTFS

#### Overview:
This notebook aims to provided an example of training a PyTorch Model on a premade task dataset
- Text-embedder: ```all-mpnet-base-v2 (pretrained Sentence Transformer)```
- Image-embedder: ```vit_b_16 (pretrained Vision Transformer)``` 
- Decoder/Classifier: Transformer Decoder only Model trained on both text and image embeddings to out put action class


#### Datasets/Training:
- There are two variations of training examples in this notebook:
    - Using a pre-saved dataset generated
        - *See the ..._dataset_gen.ipynb notebooks for how to generate and save a dataset*
    - OR, random task generation per training iteration



### Imports

In [1]:
import numpy as np
import os
from PIL import Image
import json
import math
import copy
from sklearn.model_selection import train_test_split

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torchvision
from tqdm.notebook import tqdm

from cognitive.task_bank import CompareLocTemporal
from cognitive import task_generator as tg
from cognitive import constants as const
from cognitive import stim_generator as sg
from cognitive import info_generator as ig
import random



  from .autonotebook import tqdm as notebook_tqdm
2023-09-11 11:03:59.787680: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Constants

In [2]:
DIR = 'datasets/trials_reg'  # Dataset Directory
STIM_DIR =  'data/MULTIF_5_stim/MULTIF_5_stim'   # '/Users/lucasgomez/Desktop/Neuro/Bashvian/COGEnv/COG_v3_shapenet-main/data/MULTIF_5_stim/MULTIF_5_stim' # stimulus set
BATCH_SIZE = 1
MAX_FRAMES = 6  # the max possible frames across tasks
VIT_OUT_DIM = 1000 # vision transformer output dimension
LM_OUT_DIM = 768 # language model output dimension

const.DATA = const.Data(dir_path=STIM_DIR)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

Stimuli Directory:  data/MULTIF_5_stim/MULTIF_5_stim
OrderedDict([((0, 0), [(0.0, 0.5), (0.0, 0.5)]), ((0, 1), [(0.0, 0.5), (0.5, 1.0)]), ((1, 0), [(0.5, 1.0), (0.0, 0.5)]), ((1, 1), [(0.5, 1.0), (0.5, 1.0)])])


# Datasets
- Read in the pregenerated task trials organized into frames, instructions, and correct actions. 

In [3]:
frames = []
infos = []

for trial_fp in os.listdir(DIR):
    trial_fp = os.path.join(DIR, trial_fp)
    imgs = []
    for fp in os.listdir(trial_fp):
        fp = os.path.join(trial_fp, fp)
        if fp[-4:] == '.png':
            imgs.append(np.rollaxis(np.array(Image.open(fp), dtype=np.float32),2,0))
        else:
            infos.append(json.load(open(fp)))
    if len(imgs) > MAX_FRAMES:
        raise Exception(trial_fp + " contains more frames than the set maximum (MAX_FRAMES) !!!")
    frames.append(np.array(imgs))



In [4]:
instructions = [x['instruction'] for x in infos]
raw_target_actions = [x['answers'] for x in infos]

### Encode Target Actions
- Encodes the target actions into one hot encoding vectors corresponding to the actions ```true, false, and null```

In [5]:
action_map = {'true': 0, 'false': 1, 'null': 2}

target_actions = []

for actions in raw_target_actions:
    encoded = []
    for action in actions:
        encoded.append(action_map[action])
    target_actions.append(encoded)


## Instruction Dataset

In [6]:
class InstructionsDataset(Dataset):
  """
    Pytorch Dataset class to load the Instructions Data

    Data members:
      instructions: list of instructions
      n_ins: number of instructions in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_ins
      __getitem__: returns an instruction
  """

  def __init__(self, x):

    self.instructions = x

    self.n_ins = len(self.instructions)

    return

  def __len__(self):
    """
      Returns number of instructions in the Dataset
    """

    return self.n_ins

  def __getitem__(self, idx):
    """
      Given an index return a instruction at that index
    """

    return self.instructions[idx]

In [7]:
class InstructionsCollator(object):
  """
    Data Collator used for GPT2 in a classificaiton tasks

    Args:
      use_tokenizer :
        Transformer type tokenizer used to process raw text into numbers.

    Data members:
      use_tokenizer: Tokenizer to be used inside the class.

    Member functions:
      __init__: ctor
      __call__: tokenize input

    """

  def __init__(self, use_tokenizer):

    self.use_tokenizer = use_tokenizer

    return

  def __call__(self, instructions):
    """
        Tokenizes input
    """

    # Call tokenizer
    inputs = self.use_tokenizer(instructions, padding=True, truncation=True, return_tensors='pt')


    return inputs


In [8]:
# Pretrained Language Model and Tokenizer 
lm_encoder = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')

# Create data collator to encode text and labels into numbers.
InstructionsCollator = InstructionsCollator(use_tokenizer=tokenizer)

# Create pytorch dataset for instructions
ins_train_dataset = InstructionsDataset(instructions)

# Move pytorch dataset into dataloader 
ins_train_dataloader = DataLoader(ins_train_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=InstructionsCollator)


## Frames Dataset

In [9]:
class FramesDataset(Dataset):
  """
    Pytorch Dataset class to load the Frame Data

    Data members:
      frames``ist of frames
      n_imgs: number of iamges in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_imgs
      __getitem__: returns an frame
  """

  def __init__(self, x):

    self.frames = x

    self.n_imgs = len(self.frames)

    return

  def __len__(self):
    """
      Returns number of frames in the Dataset
    """

    return self.n_imgs

  def __getitem__(self, idx):
    """
      Given an index return a frame
    """

    return torch.tensor(self.frames[idx])

In [10]:
# Pretrained Vision Transformer
vit_encoder = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)

# Create pytorch dataset for instructions
frames_train_dataset = FramesDataset(frames)

# Move pytorch dataset into dataloader 
frames_train_dataloader = DataLoader(frames_train_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Language Encoder

### Language Embedder

In [11]:
def lm_embedder(instruction, encoder):
    #Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[0] # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    # Compute token embeddings
    with torch.no_grad():
        lm_output = encoder(**instruction)

    # Perform pooling
    sentence_embeddings = mean_pooling(lm_output, instruction['attention_mask'])
    
    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        
    return sentence_embeddings

# Image Encoder

### Position Embeddings

### Image Embedder
- We pad frames based on max possible

In [12]:
def img_embedder(frames, encoder):
    with torch.no_grad():
        vit_out = encoder(torch.tensor(frames))

    npads = MAX_FRAMES-len(vit_out)
    pad = torch.ones((npads, vit_out.shape[1]))
    vit_out = torch.cat((vit_out, pad))

    return vit_out

# Create Embeddings

In [13]:
lm_embeddings = []
img_embeddings = []

for i,f in zip(ins_train_dataloader,frames_train_dataloader):
   f = f[0]
   lm_embeddings.append(lm_embedder(i, lm_encoder))
   img_embeddings.append(img_embedder(f, vit_encoder))

  vit_out = encoder(torch.tensor(frames))


# Embeddings Dataset

In [14]:
class EmbeddingsDataset(Dataset):
  """
    Pytorch Dataset class to load the embedded data

    Data members:
      lm_embeddings: list of language model embeddings
      img_embeddings: list of language model embeddings
      n_embs: number of embeddings in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_ins
      __getitem__: returns an instruction
  """

  def __init__(self, lm_embeddings, img_embeddings, actions):

    self.lm_embeddings = lm_embeddings
    self.img_embeddings = img_embeddings
    self.actions = actions

    self.n_embs = len(self.lm_embeddings)

    return

  def __len__(self):
    """
      Returns number of instructions in the Dataset
    """

    return self.n_embs

  def __getitem__(self, idx):
    """
      Given an index return a instruction at that index
    """

    # Map One Hot to
    # actions = [np.where(action == 1)[0][0] for action in self.actions[idx]]
    

    return {'instruction':self.lm_embeddings[idx], 'frames':self.img_embeddings[idx], 'actions':torch.tensor(self.actions[idx], dtype=torch.float32)}

In [15]:
# Create pytorch dataset for train, val, test data

train_split_lm, test_split_lm, train_split_img, test_split_img, train_split_ta, test_split_ta = train_test_split(lm_embeddings, img_embeddings, target_actions,test_size=0.1, random_state=1)

train_split_lm, val_split_lm, train_split_img, val_split_img, train_split_ta, val_split_ta  = train_test_split(train_split_lm, train_split_img, train_split_ta, test_size=len(test_split_lm)/len(train_split_lm), random_state=1) 


print(len(train_split_lm))
print(len(test_split_lm))
print(len(val_split_lm))


train_dataset = EmbeddingsDataset(train_split_lm, train_split_img, train_split_ta)
val_dataset = EmbeddingsDataset(val_split_lm, val_split_img, val_split_ta)
test_dataset  = EmbeddingsDataset(test_split_lm, test_split_img, test_split_ta)

# Move pytorch datasets into dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

800
100
100


# Action Decoder

In [16]:
from matplotlib import projections
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CausalMatchTransformer(nn.Module):
    """
    Pytorch based transformer decoder model
    """

    # Initialize Model with Params
    def __init__(self, nframes=MAX_FRAMES, blocks=3, nhead=5, emb_dim=VIT_OUT_DIM, classes=3, device=device):
        super().__init__()

        # Device
        self.device = device

        # Embedding Dimension
        self.emb_dim = emb_dim

        # Number of frames
        self.nframes = nframes

        # Frame Position Embedder Layer
        self.pos_emb = nn.Parameter(torch.Tensor(nframes,emb_dim)).to(device)
        torch.nn.init.xavier_uniform_(
           self.pos_emb,
           gain=torch.nn.init.calculate_gain("linear"))

        # Instruction Dim Projection Layer
        self.lm_linear_layer = nn.Linear(LM_OUT_DIM, emb_dim).to(device)

        # Decoder Layers
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=nhead, batch_first=True).to(device)
        self.decoder_layers = _get_clones(self.decoder_layer, blocks)
        
        # Decoder
        self.decoder = nn.TransformerDecoder(self.decoder_layers, num_layers=blocks).to(device)

        # Action classifier
        self.classifier = nn.Linear(emb_dim, classes).to(device)

    # Function for forward pass
    def forward(self, instruction, frames, mask, padding_mask):

        # Project instruction embedding
        instruction = self.lm_linear_layer(instruction)

        # Add the frame position embedding
        for i in range(len(frames)):
            frames[i] += self.pos_emb[i,:]

        # Apply each Decoder Layer (block)
        for layer in self.decoder_layers:
            frames = layer(frames, instruction, tgt_mask=mask, tgt_key_padding_mask=padding_mask, ) 

        # Pass through linear layer for classification
        output = self.classifier(frames)

        return output
 
# Creates a list of torch duplicate torch modules
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

# Creates a square Sequential/Causal mask of size sz
def generate_causal_mask(sz: int) -> Tensor:
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

# Generates a padding masks for each sequence in a batch
def generate_pad_mask(batch):

    pad_tensor = torch.ones((batch.shape[2])).to(device)

    mask = np.zeros((batch.shape[0],batch.shape[1]))

    for s in range(0, batch.shape[0]):
        for v in range(0, batch[s].shape[0]):
            new_s = torch.all(batch[s][v] == pad_tensor)
            mask[s][v] = new_s

    return torch.tensor(mask).bool().to(device)

In [17]:
VIT_OUT_DIM

1000

In [18]:
model = CausalMatchTransformer(nframes=MAX_FRAMES,
                               blocks=3,
                               nhead=8,
                               emb_dim=VIT_OUT_DIM,
                               classes=len(action_map.keys()),
                               device=device).float().to(device)

# Training

## Pre-generated

In [19]:
# Training configurations
epochs = 7

criterion = nn.CrossEntropyLoss()

lr = 1e-5
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=lr
)

mask = generate_causal_mask(MAX_FRAMES).to(device)

# torch.manual_seed(0)

In [20]:
# Training and validation loop

# Store the average loss after each epoch
all_loss = {'train_loss':[], 'val_loss':[]}
all_acc = {'train_acc':[], 'val_acc':[]}

print("starting")
for epoch in range(epochs):
    print(f"epoch={epoch}")

    # Epoch stat trackers
    epoch_loss = 0
    epoch_correct = 0
    epoch_count = 0
    for idx, batch in enumerate(iter(train_dataloader)):

        # Inputs and Targets
        instruction = batch['instruction']
        frames = batch['frames']
        targets = batch['actions']

        # print(frames.shape)
        # print(instruction.shape)
        # print(len(targets))

        # Frame Padding
        padding_mask = generate_pad_mask(batch=frames)
        pad_indexes = np.argwhere(np.array(padding_mask) == False)[:,1]

        # Get predictions
        predictions = model(instruction, frames, mask, padding_mask)
        predictions = predictions[:,pad_indexes]

        # Get Loss by permuting the predictions into correct shape (batch_size, n_classes, seq_len)
        loss = criterion(predictions.permute(0, 2, 1), targets.long())

        # Track stats
        # torch.nn.functional.softmax(input, dim=None, _stacklevel=3, dtype=None)
        correct = predictions.argmax(dim=-1) == targets
        
        epoch_correct += correct.sum().item()
        epoch_count += correct.size(1)
        epoch_loss += loss.item()

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

    # Validate on validation set every 5 epochs
    if (epoch+1) % 2 == 0 or epoch == epochs:
        # Turn off gradient calcs
        with torch.no_grad():
            val_epoch_loss = 0
            val_epoch_correct = 0
            val_epoch_count = 0

            for idx, batch in enumerate(iter(val_dataloader)):
                # Inputs and Targets
                instruction = batch['instruction']
                frames = batch['frames']
                targets = batch['actions']

                # Frame Padding
                padding_mask = generate_pad_mask(batch=frames)
                pad_indexes = np.argwhere(np.array(padding_mask) == False)[:,1]

                # Get predictions
                predictions = model(instruction, frames, mask, padding_mask)
                predictions = predictions[:,pad_indexes]

                # Get Loss
                loss = criterion(predictions.permute(0, 2, 1), targets.long())


                correct = predictions.argmax(dim=-1) == targets

                val_epoch_correct += correct.sum().item()
                val_epoch_count += correct.size(1)
                val_epoch_loss += loss.item()

        # Track loss and acc ever 5 epochs
        avg_train_loss = epoch_loss / len(train_dataloader)
        avg_val_loss = val_epoch_loss / len(val_dataloader)

    
        all_loss['val_loss'].append(avg_val_loss)
        all_acc['val_acc'].append(val_epoch_correct / val_epoch_count)

        all_acc['train_acc'].append(epoch_correct / epoch_count)
        all_loss['train_loss'].append(avg_train_loss)

        print(all_acc)

starting
epoch=0




epoch=1
{'train_acc': [0.8289186662511686], 'val_acc': [0.8672985781990521]}
epoch=2
epoch=3
{'train_acc': [0.8289186662511686, 0.8837644125895918], 'val_acc': [0.8672985781990521, 0.9004739336492891]}
epoch=4
epoch=5
{'train_acc': [0.8289186662511686, 0.8837644125895918, 0.9055780617014646], 'val_acc': [0.8672985781990521, 0.9004739336492891, 0.8957345971563981]}
epoch=6


In [21]:
all_acc

{'train_acc': [0.8289186662511686, 0.8837644125895918, 0.9055780617014646],
 'val_acc': [0.8672985781990521, 0.9004739336492891, 0.8957345971563981]}

## Evaluating

In [22]:
with torch.no_grad():
    test_correct = 0
    test_count = 0

    for idx, batch in enumerate(iter(test_dataloader)):
        # Inputs and Targets
        instruction = batch['instruction']
        frames = batch['frames']
        targets = batch['actions']

        print(targets)

        # Frame Padding
        padding_mask = generate_pad_mask(batch=frames)
        pad_indexes = np.argwhere(np.array(padding_mask) == False)[:,1]

        # Get predictions
        predictions = model(instruction, frames, mask, padding_mask)

        predictions = predictions[:,pad_indexes]

        print(predictions.argmax(dim=-1))

        correct = predictions.argmax(dim=-1) == targets

        print(correct)

        test_correct += correct.sum().item()
        test_count += correct.size(1)

print("Test Accuracy: ", round(test_correct/test_count,4))

tensor([[2., 2., 2., 2., 1.]])
tensor([[2, 2, 2, 2, 0]])
tensor([[ True,  True,  True,  True, False]])
tensor([[2., 2., 2., 1.]])
tensor([[2, 2, 2, 1]])
tensor([[True, True, True, True]])
tensor([[2., 2., 2., 2., 2., 1.]])
tensor([[2, 2, 2, 2, 2, 1]])
tensor([[True, True, True, True, True, True]])
tensor([[2., 1.]])
tensor([[2, 1]])
tensor([[True, True]])
tensor([[2., 2., 0.]])
tensor([[2, 2, 0]])
tensor([[True, True, True]])
tensor([[2., 0.]])
tensor([[2, 1]])
tensor([[ True, False]])
tensor([[2., 2., 2., 2., 2., 0.]])
tensor([[2, 2, 2, 2, 2, 1]])
tensor([[ True,  True,  True,  True,  True, False]])
tensor([[2., 2., 2., 2., 0.]])
tensor([[2, 2, 2, 2, 0]])
tensor([[True, True, True, True, True]])
tensor([[2., 1.]])
tensor([[2, 1]])
tensor([[True, True]])
tensor([[2., 2., 0.]])
tensor([[2, 2, 1]])
tensor([[ True,  True, False]])
tensor([[2., 2., 1.]])
tensor([[2, 2, 1]])
tensor([[True, True, True]])
tensor([[2., 2., 0.]])
tensor([[2, 2, 0]])
tensor([[True, True, True]])
tensor([[2., 2.,

## Random Task Generation

- Here we generate each task randomly for every training iteration

In [23]:
# Training configurations
epochs = 10

criterion = nn.CrossEntropyLoss()

lr = 1e-3
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=lr
)

mask = generate_causal_mask(MAX_FRAMES).to(device)

# torch.manual_seed(0)

In [24]:
# Generates a random location comparison task
def gen_comp_loc(lm_embedder, img_embedder):
    # f1 = random.randint(0, 4)
    # f2 = random.randint(f1+1,5)

    # task = CompareLocTemporal(whens=['last'+str(f1),'last'+str(f2)])

    task = CompareLocTemporal(whens=['last0','last1'])

    frame_info = ig.FrameInfo(task, task.generate_objset())
    compo_info = ig.TaskInfoCompo(task, frame_info)
    objset = compo_info.frame_info.objset

    frames = []
    for i, (epoch, frame) in enumerate(zip(sg.render(objset, 224), compo_info.frame_info)):
        if not any('ending' in description for description in frame.description):
            sg.add_fixation_cue(epoch)
        img = np.rollaxis(np.array(Image.fromarray(epoch, 'RGB'), dtype=np.float32),2,0)
        frames.append(img)

    frames = torch.tensor(frames)

    frames = img_embedder(frames, vit_encoder).unsqueeze(0)

    _, compo_example, _ = compo_info.get_examples()

    instruction = compo_example['instruction']

    instruction = tokenizer(instruction, padding=True, truncation=True, return_tensors='pt')

    instruction = lm_embedder(instruction, lm_encoder).unsqueeze(0)

    actions = compo_example['answers']

    action_map = {'true': 0, 'false': 1, 'null': 2}

    target_actions = []
    for action in actions:
        target_actions.append(action_map[action])

    return instruction, frames, torch.tensor(target_actions, dtype=torch.float32).unsqueeze(0)

In [25]:
# Training and validation loop using random generation
n_tasks = 0

# Store the average loss after each epoch
all_loss = {'train_loss':[], 'val_loss':[]}
all_acc = {'train_acc':[], 'val_acc':[]}

print("starting")
for epoch in range(epochs):
    print(f"epoch={epoch}")

    # Epoch stat trackers
    epoch_loss = 0
    epoch_correct = 0
    epoch_count = 0

    for i in range(n_tasks):

        # Inputs and Targets
        instruction, frames, targets = gen_comp_loc(lm_embedder, img_embedder)

        # print(frames.shape)
        # print(instruction.shape)
        # print(len(targets))

        # Frame Padding
        padding_mask = generate_pad_mask(batch=frames)
        pad_indexes = np.argwhere(np.array(padding_mask) == False)[:,1]

        # Get predictions
        predictions = model(instruction, frames, mask, padding_mask)
        predictions = predictions[:,pad_indexes]

        print(predictions.shape)
        print(predictions.permute(0, 2, 1).shape)
        print(targets.shape)

        # Get Loss by permuting the predictions into correct shape (batch_size, n_classes, seq_len)
        loss = criterion(predictions.permute(0, 2, 1), targets.long())

        # Track stats
        # torch.nn.functional.softmax(input, dim=None, _stacklevel=3, dtype=None)
        correct = predictions.argmax(dim=-1) == targets
        acc = correct.sum().item() / correct.size(0)
        epoch_correct += correct.sum().item()
        epoch_count += correct.size(0)
        epoch_loss += loss.item()

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

    # Validate on validation set every 5 epochs
    if (epoch+1) % 2 == 0:
        # Turn off gradient calcs
        with torch.no_grad():
            val_epoch_loss = 0
            val_epoch_correct = 0
            val_epoch_count = 0

            for idx, batch in enumerate(iter(val_dataloader)):
                # Inputs and Targets
                instruction = batch['instruction']
                frames = batch['frames']
                targets = batch['actions']
                
                # Frame Padding
                padding_mask = generate_pad_mask(batch=frames)
                pad_indexes = np.argwhere(np.array(padding_mask) == False)[:,1]

                # Get predictions
                predictions = model(instruction, frames, mask, padding_mask)
                predictions = predictions[:,pad_indexes]

                # Get Loss
                loss = criterion(predictions.permute(0, 2, 1), targets.long())

                correct = predictions.argmax(dim=-1) == targets
                acc = correct.sum().item() / correct.size(0)

                val_epoch_correct += correct.sum().item()
                val_epoch_count += correct.size(0)
                val_epoch_loss += loss.item()

        # Track loss and acc ever 5 epochs
        avg_train_loss = epoch_loss / len(train_dataloader)
        avg_val_loss = val_epoch_loss / len(val_dataloader)


        all_loss['val_loss'].append(avg_val_loss)
        all_acc['val_acc'].append(val_epoch_correct / val_epoch_count)

        all_acc['train_acc'].append(epoch_correct / epoch_count)
        all_loss['train_loss'].append(avg_train_loss)

starting
epoch=0
epoch=1


ZeroDivisionError: division by zero