# Installing Dependencies and accessing Google Drive

In [None]:
# Mount Google Drive so later cells can read the dataset zip and write model
# checkpoints under /content/drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install pyctm==0.0.13

Collecting pyctm==0.0.13
  Downloading pyctm-0.0.13-py3-none-any.whl.metadata (389 bytes)
Collecting confluent-kafka (from pyctm==0.0.13)
  Downloading confluent_kafka-2.6.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (2.3 kB)
Collecting tk (from pyctm==0.0.13)
  Downloading tk-0.1.0-py3-none-any.whl.metadata (693 bytes)
Downloading pyctm-0.0.13-py3-none-any.whl (36 kB)
Downloading confluent_kafka-2.6.0-cp310-cp310-manylinux_2_28_x86_64.whl (3.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.9/3.9 MB[0m [31m43.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tk-0.1.0-py3-none-any.whl (3.9 kB)
Installing collected packages: tk, confluent-kafka, pyctm
Successfully installed confluent-kafka-2.6.0 pyctm-0.0.13 tk-0.1.0


# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as ff
import torch.optim as optim
import torch.utils.data as data_utils
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from collections import OrderedDict
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torch.nn import Parameter
import torchvision
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from torch.optim.lr_scheduler import StepLR

from torchvision import models
import pandas as pd
import json as json
import numpy as np
from IPython.display import clear_output
from tqdm import tqdm

import math
import time
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


import seaborn as sns
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix, roc_auc_score, roc_curve, auc

from pyctm.representation.sdr_idea_array import SDRIdeaArray

from pyctm.representation.sdr_idea_deserializer import SDRIdeaDeserializer
from pyctm.representation.sdr_idea_serializer import SDRIdeaSerializer
from pyctm.representation.dictionary import Dictionary
from pyctm.representation.idea import Idea

from pyctm.representation.array_dictionary import ArrayDictionary

from pyctm.representation.idea import Idea
from pyctm.representation.sdr_idea_array_serializer import SDRIdeaArraySerializer

from prettytable import PrettyTable

In [None]:
# Select the compute device: the first CUDA GPU when one is available, otherwise
# the CPU. `use_gpu` records which branch was taken. (Messages are in Portuguese.)
print("\nChecando GPU...")
print("Dispositivo cuda disponível? ", end='')

use_gpu = torch.cuda.is_available() is True
if not use_gpu:
    device = torch.device("cpu")
    print("não. Usando CPU.")
else:
    from torch.cuda import get_device_name
    device = torch.device("cuda:0")
    print("sim: " + str(device))
    print("GPU:" + str(get_device_name(0)))


Checando GPU...
Dispositivo cuda disponível? sim: cuda:0
GPU:NVIDIA A100-SXM4-40GB


## Utility Functions

In [None]:


import matplotlib.pyplot as plt
import time
from torchvision.utils import make_grid
import torch

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=4, show=True):
    '''
    Plot the first channel of a batch of images as a grid of heat maps.

    Args:
        image_tensor: tensor indexed as image_tensor[i][0]; presumably shaped
            (batch, channels, H, W) -- TODO confirm with callers.
        num_images: maximum number of images to draw; clamped to the batch size.
        size: unused; kept for backward compatibility with existing callers.
        nrow: number of images per grid row. The number of grid rows is derived
            from num_images so that every image gets a slot.
        show: call plt.show() at the end when True.
    '''
    # Convert once to a CPU numpy array. This also works for CUDA tensors.
    image_array = image_tensor.detach().cpu().numpy()

    num_images = min(num_images, len(image_array))
    if num_images <= 0:
        return

    nrow = max(1, nrow)
    # The grid used to be fixed at nrow x nrow, so the defaults (25 images,
    # nrow=4 -> 16 slots) raised a ValueError from plt.subplot.
    n_grid_rows = math.ceil(num_images / nrow)

    # Use one colour scale for all panels (the min/max of the whole tensor).
    # Compute it from the CPU copy, once, so matplotlib never receives a tensor.
    vmin = float(image_array.min())
    vmax = float(image_array.max())
    cmap = 'viridis'

    plt.figure(figsize=(15, 15))
    for i in range(num_images):
        plt.subplot(n_grid_rows, nrow, i + 1)
        plt.imshow(image_array[i][0], cmap=cmap, vmin=vmin, vmax=vmax)
        plt.axis('off')

    plt.colorbar()

    if show:
        plt.show()

In [None]:
def weights_init(model):
    '''
    Re-initialise the weights of every submodule of `model` in place.

    - Conv2d / ConvTranspose2d: weights ~ N(0, 0.02); biases are left untouched.
    - Linear: Xavier-uniform weights, and a zero bias when the layer has one.

    Other module types (e.g. Embedding, LayerNorm) keep their default init.
    The function walks model.modules() itself, so a single call on the root
    module covers the whole network.
    '''
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            # nn.Linear(..., bias=False) stores bias=None, which constant_ cannot fill.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

# Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import TransformerEncoderLayer, TransformerDecoderLayer
import copy

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to an embedded sequence.

    Even feature columns hold sin(pos * w_k) and odd columns hold cos(pos * w_k),
    where w_k = 10000^(-2k / d_model). The table is a registered buffer, so it
    follows `.to(device)` and is stored in `state_dict`.

    Args:
        d_model: feature size of the embeddings. Odd sizes are supported.
        max_seq_length: longest sequence the table covers.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: ceil(d_model / 2) of them.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns. Trimming div_term stops
        # an odd d_model from raising a shape-mismatch error.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Shape (1, max_seq_length, d_model) so it broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return x plus the encoding of its first x.size(1) positions.

        Expects x shaped (batch, seq_len, d_model) so it can broadcast with the
        (1, max_seq_length, d_model) table.

        Raises:
            ValueError: if seq_len exceeds max_seq_length.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

class TransformerXLAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Despite the name, this is standard multi-head attention. It has no
    segment-level recurrence memory and no relative position encoding.

    Note: scores are scaled by sqrt(d_model // n_head), the per-head key size.
    Checkpoints trained with the earlier sqrt(d_model) scaling will behave
    differently.

    Args:
        d_model: model/feature size; must be divisible by n_head.
        n_head: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if d_model is not divisible by n_head.
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super(TransformerXLAttention, self).__init__()
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")

        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head

        # Query, key and value projections.
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

        # Output projection, applied after the heads are concatenated.
        self.W_O = nn.Linear(d_model, d_model)

        # Dropout on the attention probabilities.
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        """Attend from the queries Q over the keys/values (K, V).

        Args:
            Q: (batch, q_len, d_model) queries.
            K, V: (batch, kv_len, d_model) keys and values.
            mask: optional tensor broadcastable to (batch, n_head, q_len, kv_len).
                Positions where mask == 0 are excluded from attention.

        Returns:
            (batch, q_len, d_model) tensor.
        """
        Q = self.split_heads(self.W_Q(Q), self.n_head)
        K = self.split_heads(self.W_K(K), self.n_head)
        V = self.split_heads(self.W_V(V), self.n_head)

        # Scale by sqrt(d_k), the per-head key size. Scaling by sqrt(d_model)
        # shrinks the logits by an extra factor of sqrt(n_head) and over-smooths
        # the softmax.
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Fill with the dtype's most negative value rather than a fixed -1e9,
            # which cannot be represented in float16.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        # Weighted sum of values, heads merged back, then the output projection.
        attn_output = torch.matmul(attn_probs, V)
        attn_output = self.combine_heads(attn_output)
        return self.W_O(attn_output)

    def split_heads(self, x, n_head):
        """Reshape (batch, seq_len, d_model) -> (batch, n_head, seq_len, d_model // n_head)."""
        batch_size, seq_len, d_model = x.size()
        head_dim = d_model // n_head
        x = x.view(batch_size, seq_len, n_head, head_dim)
        return x.transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of split_heads: (batch, n_head, seq_len, head_dim) -> (batch, seq_len, n_head * head_dim)."""
        batch_size, n_head, seq_len, head_dim = x.size()
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, seq_len, n_head * head_dim)


class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position: fc2(relu(fc1(x)))."""

    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        # Submodule names and creation order are unchanged, so state_dict keys
        # and random initialisation match.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

class ResidualLayer(nn.Module):
    """Pre-norm residual wrapper: returns x + sublayer(LayerNorm(x))."""

    def __init__(self, sublayer, input_dim):
        super(ResidualLayer, self).__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        normalized = self.norm(x)
        update = self.sublayer(normalized)
        return x + update

class ConvolutionalEmbeddingLayer1D(nn.Module):
    """Three stacked Conv1d + ReLU layers (kernel 3, padding 1, length-preserving).

    Takes x indexed as (batch, seq_len, input_dim), convolves along seq_len and
    returns (batch, seq_len, d_model).
    """

    def __init__(self, input_dim, d_model):
        super(ConvolutionalEmbeddingLayer1D, self).__init__()
        self.conv1 = nn.Conv1d(input_dim, d_model, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)

    def forward(self, x):
        # Conv1d expects channels in dim 1: (batch, seq, feat) -> (batch, feat, seq).
        h = x.permute(0, 2, 1)
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h))
        return h.permute(0, 2, 1)

class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder block.

    x -> LayerNorm(x + Dropout(SelfAttention(x))) -> LayerNorm(h + Dropout(FFN(h)))
    """

    def __init__(self, d_model, n_head, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        # Creation order is unchanged so parameter order and RNG draws match.
        self.self_attn = TransformerXLAttention(d_model, n_head, dropout)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Apply the block. `mask` is passed unchanged to the self-attention."""
        h = self.norm1(x + self.dropout(self.self_attn(x, x, x, mask)))
        return self.norm2(h + self.dropout(self.feed_forward(h)))

class DecoderLayer(nn.Module):
    """Post-norm Transformer decoder block.

    Stages: masked self-attention, cross-attention over the encoder output, then
    a position-wise feed-forward network. Each stage is followed by dropout, a
    residual add and LayerNorm.
    """

    def __init__(self, d_model, n_head, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = TransformerXLAttention(d_model, n_head, dropout)
        self.cross_attn = TransformerXLAttention(d_model, n_head, dropout)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        # 1) self-attention over the target, restricted by tgt_mask
        h = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(h))
        # 2) cross-attention: queries from the decoder, keys/values from the encoder
        h = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(h))
        # 3) position-wise feed-forward
        h = self.feed_forward(x)
        return self.norm3(x + self.dropout(h))


class PlanningTransformer(nn.Module):
    """Encoder-decoder Transformer over token ids that produces next-token logits.

    Token id 0 is treated as padding by `generate_mask`.

    Args:
        vocabulary_size: number of token ids (embedding rows and output logits).
        d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
        dropout: standard Transformer hyper-parameters.
        max_seq_len: longest src/tgt sequence the positional encoding covers.
        device: kept for backward compatibility. Masks are built on the device
            of the input tensors, so moving the module with .to() is safe.
    """

    def __init__(self, vocabulary_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout, max_seq_len=20, device='cpu'):
        super(PlanningTransformer, self).__init__()

        self.vocabulary_size = vocabulary_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.device = device

        self.encoder_embedding = nn.Embedding(vocabulary_size, d_model)
        self.decoder_embedding = nn.Embedding(vocabulary_size, d_model)

        # One sinusoidal positional encoding shared by the source and target streams.
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, nhead, dim_feedforward, dropout) for _ in range(num_encoder_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, nhead, dim_feedforward, dropout) for _ in range(num_decoder_layers)])

        self.dropout = nn.Dropout(dropout)

        self.output_layer = nn.Linear(d_model, self.vocabulary_size)

    def generate_mask(self, src, tgt):
        """Build boolean attention masks (True = may attend).

        Returns:
            src_mask: (batch, 1, 1, src_len). False at source padding (id 0) keys.
            tgt_mask: (batch, 1, tgt_len, tgt_len). A causal (lower-triangular)
                mask AND-ed with a mask that blanks the rows of padding queries.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        # Build on tgt's device. self.device goes stale when the module is moved
        # with .to()/.cuda() after construction, which caused a device mismatch.
        nopeak_mask = torch.tril(torch.ones(1, seq_length, seq_length, device=tgt.device)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def generate_square_subsequent_mask(self, sz):
        """Additive causal mask (0 on/below the diagonal, -inf above). Not used by forward()."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def create_positional_encoding(self, max_len, d_model):
        """Return a frozen (1, max_len, d_model) sinusoidal table. Not used by forward()."""
        # Create a matrix of positional encodings
        positional_encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term)
        positional_encoding = positional_encoding.unsqueeze(0)
        return nn.Parameter(positional_encoding, requires_grad=False)

    def forward(self, src, tgt):
        """Return logits of shape (batch, tgt_len, vocabulary_size).

        Args:
            src: (batch, src_len) token ids fed to the encoder.
            tgt: (batch, tgt_len) token ids fed to the decoder (teacher forcing).
        """
        src_mask, tgt_mask = self.generate_mask(src, tgt)

        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        # output_layer already yields (batch, tgt_len, vocabulary_size), so no reshape is needed.
        return self.output_layer(dec_output)

## Model Test

In [None]:
# Smoke test: build a model with the hyper-parameters below and check that a
# forward pass on dummy token ids returns (batch, tgt_len, vocabulary_size) logits.
# NOTE: these globals are reassigned later in the "Hyperparameters" section.
vocabulary_size = 33  # Number of token ids (embedding rows / output logits)
d_model = 768  # Size of the transformer layers
nhead = 12  # Number of heads in the multi-head attention
num_encoder_layers = 2  # Number of sub-encoder-layers in the encoder
num_decoder_layers = 4  # Number of sub-decoder-layers in the decoder
dim_feedforward = 768  # Dimension of the feedforward network model
dropout = 0.3  # Dropout value
learning_rate = 1e-4  # Learning rate
epochs = 100  # Epochs
NGPU = 1  # GPU Number
batch_size = 64
max_seq_len = 626

# Named sample_* to avoid shadowing the built-in `input`.
sample_src = torch.zeros((10, max_seq_len)).long()
sample_tgt = torch.ones((10, 15)).long()

model = PlanningTransformer(vocabulary_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout, max_seq_len, device='cpu')
model.eval()
with torch.no_grad():
    preds = model(sample_src, sample_tgt)

assert preds.shape == (10, 15, vocabulary_size), preds.shape
print(preds.shape)

torch.Size([10, 15, 33])


# Data Preparation

## Dataset Class

In [None]:
import torch
import torch.nn.functional as F
import numpy as np

import torch
import numpy as np

class PlanDataset(torch.utils.data.Dataset):
    """Map-style dataset over a DataFrame with "input" and "output" columns.

    Each row's "input" / "output" cell holds a sequence of token ids. Items are
    returned as a pair of int64 (torch.long) tensors.

    Args:
        dataset: pandas DataFrame with "input" and "output" columns.
        transform: optional callable applied to the input tensor of each item.
            (It used to be accepted but silently ignored.) The label is
            returned unchanged.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # .values[...] is positional, so a shuffled or non-unique index is fine.
        src = torch.from_numpy(np.asarray(self.dataset["input"].values[index])).long()
        label = torch.from_numpy(np.asarray(self.dataset["output"].values[index])).long()

        if self.transform is not None:
            src = self.transform(src)

        return src, label



## Loading Data

In [None]:
# Download the dataset archive from Google Drive by file id. Per the cell output
# it is saved as /content/dataPlanSDR35.zip.
# NOTE(review): the unzip cell below extracts a copy from the mounted Drive
# (/content/drive/MyDrive/data/SDR/dataPlanSDR35.zip), so this download looks
# unused -- confirm which source is intended.
!gdown --id 13kkDkMDlTOfFU4kNQrEenvMPcDYC64hr


Downloading...
From (original): https://drive.google.com/uc?id=13kkDkMDlTOfFU4kNQrEenvMPcDYC64hr
From (redirected): https://drive.google.com/uc?id=13kkDkMDlTOfFU4kNQrEenvMPcDYC64hr&confirm=t&uuid=39e007ca-46c0-46a8-b74d-549e50f90531
To: /content/dataPlanSDR35.zip
100% 9.50M/9.50M [00:00<00:00, 86.5MB/s]


In [None]:
# Remove any previously extracted copy, then extract the archive into ./dataPlanSDR/.
# NOTE(review): this reads the zip from the mounted Drive, not the
# /content/dataPlanSDR35.zip fetched by the gdown cell -- confirm which copy is intended.
!rm -rf dataPlanSDR/
!unzip "/content/drive/MyDrive/data/SDR/dataPlanSDR35.zip" -d .

Archive:  /content/drive/MyDrive/data/SDR/dataPlanSDR35.zip
   creating: ./dataPlanSDR/
  inflating: ./dataPlanSDR/dataPlanSDR_1606.json  
  inflating: ./dataPlanSDR/dataPlanSDR_1609.json  
  inflating: ./dataPlanSDR/dataPlanSDR_1610.json  
  inflating: ./dataPlanSDR/dataPlanSDR_1614.json  
  inflating: ./dataPlanSDR/dataPlanSDR_1617.json  
  inflating: ./dataPlanSDR/dataPlanSDR_1618.json  
  inflating: ./dataPlanSDR/dataPlanSDR_1620.json  
  inflating: ./dataPlanSDR/dataPlanSDR_1623.json  
  inflating: ./dataPlanSDR/dataPlanSDR_1625.json  
  inflating: ./dataPlanSDR/dataPlanSDR_0.json  
  inflating: ./dataPlanSDR/dataPlanSDR_1.json  
  inflating: ./dataPlanSDR/dataPlanSDR_2.json  
  inflating: ./dataPlanSDR/dataPlanSDR_6.json  
  inflating: ./dataPlanSDR/dataPlanSDR_7.json  
  inflating: ./dataPlanSDR/dataPlanSDR_10.json  
  inflating: ./dataPlanSDR/dataPlanSDR_15.json  
  inflating: ./dataPlanSDR/dataPlanSDR_16.json  
  inflating: ./dataPlanSDR/dataPlanSDR_20.json  
  inflating: ./da

In [None]:
# Load every JSON shard, dataPlanSDR_0.json ... dataPlanSDR_{NUM_DATA_FILES-1}.json,
# and concatenate them in a single call. Concatenating inside the loop copied the
# growing frame on every iteration (quadratic time). One pd.concat yields the same
# rows, order and index.
DATA_DIR = "/content/dataPlanSDR"
NUM_DATA_FILES = 2499

frames = [
    pd.read_json(f"{DATA_DIR}/dataPlanSDR_{i}.json")
    for i in tqdm(range(NUM_DATA_FILES), desc="Loading dataPlanSDR shards")
]
df = pd.concat(frames)
del frames  # the per-file frames are no longer needed once `df` exists

Loaded File - dataPlanSDR_1.json
Loaded File - dataPlanSDR_2.json
Loaded File - dataPlanSDR_3.json
Loaded File - dataPlanSDR_4.json
Loaded File - dataPlanSDR_5.json
Loaded File - dataPlanSDR_6.json
Loaded File - dataPlanSDR_7.json
Loaded File - dataPlanSDR_8.json
Loaded File - dataPlanSDR_9.json
Loaded File - dataPlanSDR_10.json
Loaded File - dataPlanSDR_11.json
Loaded File - dataPlanSDR_12.json
Loaded File - dataPlanSDR_13.json
Loaded File - dataPlanSDR_14.json
Loaded File - dataPlanSDR_15.json
Loaded File - dataPlanSDR_16.json
Loaded File - dataPlanSDR_17.json
Loaded File - dataPlanSDR_18.json
Loaded File - dataPlanSDR_19.json
Loaded File - dataPlanSDR_20.json
Loaded File - dataPlanSDR_21.json
Loaded File - dataPlanSDR_22.json
Loaded File - dataPlanSDR_23.json
Loaded File - dataPlanSDR_24.json
Loaded File - dataPlanSDR_25.json
Loaded File - dataPlanSDR_26.json
Loaded File - dataPlanSDR_27.json
Loaded File - dataPlanSDR_28.json
Loaded File - dataPlanSDR_29.json
Loaded File - dataPlanS

### Splitting Data - Train, Validation and Test

In [None]:
# Shuffle once (seeded) and split 70% / 15% / 15% into train / validation / test.
# Positional .iloc slicing replaces np.split on a DataFrame, which goes through the
# deprecated DataFrame.swapaxes and emitted a FutureWarning. The resulting frames
# are the same.
shuffled_df = df.sample(frac=1, random_state=42)
train_end = int(.7 * len(df))
validation_end = int(.85 * len(df))

train_df = shuffled_df.iloc[:train_end]
validation_df = shuffled_df.iloc[train_end:validation_end]
test_df = shuffled_df.iloc[validation_end:]

train_size = len(train_df)

print("Train Size:" + str(len(train_df)))
print("Validation Size:" + str(len(validation_df)))
print("Test Size:" + str(len(test_df)))

Train Size:174915
Validation Size:37482
Test Size:37482


  return bound(*args, **kwds)


# Training

## Hyperparameters

In [None]:
# Hyperparameters
# NOTE(review): `dictionary_map` is not defined in any cell shown here, so this
# line raises NameError on a fresh kernel. Presumably it comes from a missing cell
# that loads the token dictionary -- add that cell. The model summary printed
# further down shows Embedding(33, 768), i.e. 33 tokens in the run that produced it.
vocabulary_size = len(dictionary_map['words'].keys())
d_model = 768  # Size of the transformer layers
nhead = 12  # Number of heads in the multi-head attention layers
num_encoder_layers = 2  # Number of sub-encoder-layers in the encoder
num_decoder_layers = 4  # Number of sub-decoder-layers in the decoder
dim_feedforward = 768  # Dimension of the feedforward network model
dropout = 0.3  # Dropout value
learning_rate = 1e-4 # Learning rate
epochs = 100 # Epochs
NGPU = 1 # GPU Number; DataParallel is only used when NGPU > 1
batch_size = 64
max_seq_len = 374  # positional-encoding length; src/tgt sequences must not be longer
# NOTE(review): the smoke-test cell used max_seq_len = 626 -- confirm 374 covers the data.

## Instantiate Data Loaders

In [None]:
# Training batches are reshuffled every epoch. Validation and test are iterated in
# a fixed order so evaluation is deterministic across epochs; the mean of per-batch
# losses depends on how rows are grouped into batches.
# NOTE(review): default collation stacks tensors, so every sequence in a batch must
# have the same length -- presumably the JSON data is pre-padded; confirm.
train_plan_dataset = PlanDataset(train_df)
train_data_loader = DataLoader(train_plan_dataset, batch_size=batch_size, shuffle=True)

validation_plan_dataset = PlanDataset(validation_df)
validation_data_loader = DataLoader(validation_plan_dataset, batch_size=batch_size, shuffle=False)

test_plan_dataset = PlanDataset(test_df)
test_data_loader = DataLoader(test_plan_dataset, batch_size=1, shuffle=False)

## Instantiate Model

In [None]:
# Build the model and move its parameters and buffers to the selected device.
gen = PlanningTransformer(
    vocabulary_size,
    d_model,
    nhead,
    num_encoder_layers,
    num_decoder_layers,
    dim_feedforward,
    dropout,
    max_seq_len,
    device=device,
).to(device)

In [None]:
# Initialise the weights, then optionally wrap the model for multi-GPU training.
# weights_init walks model.modules() itself, so one call covers every submodule.
# gen.apply(weights_init) re-ran that full walk once per submodule and
# re-initialised each layer many times.
weights_init(gen)

if (device.type == 'cuda') and (NGPU > 1):
    gen = nn.DataParallel(gen, list(range(NGPU)))

gen

PlanningTransformer(
  (encoder_embedding): Embedding(33, 768)
  (decoder_embedding): Embedding(33, 768)
  (positional_encoding): PositionalEncoding()
  (encoder_layers): ModuleList(
    (0-1): 2 x EncoderLayer(
      (self_attn): TransformerXLAttention(
        (W_Q): Linear(in_features=768, out_features=768, bias=True)
        (W_K): Linear(in_features=768, out_features=768, bias=True)
        (W_V): Linear(in_features=768, out_features=768, bias=True)
        (W_O): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
      )
      (feed_forward): PositionWiseFeedForward(
        (fc1): Linear(in_features=768, out_features=768, bias=True)
        (fc2): Linear(in_features=768, out_features=768, bias=True)
        (relu): ReLU()
      )
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.3, inplace=False)
    )
  )
  (deco

## Optimizers

In [None]:
# Adam with betas (0.9, 0.98) and eps 1e-9. StepLR multiplies the learning rate
# by 0.9 every 5 scheduler steps.
opt_gen = optim.Adam(
    gen.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.98),
    eps=1e-9,
)
lr_scheduler_gen = StepLR(opt_gen, step_size=5, gamma=0.9)

## Training Method and Process

In [None]:
# Per-epoch mean losses appended by train(): training (CE) and validation (VCE).
loss_train_list, v_loss_train_list = [], []

# Token id 0 is the padding id used by the masks, so those target positions are
# excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [None]:
def validate(model, validation_loader, epoch, epochs):
    """Compute the mean teacher-forced cross-entropy over a validation loader.

    The decoder is fed label[:, :-1] and scored against label[:, 1:]. The
    notebook globals `criterion` and `device` are used.

    Args:
        model: model called as model(src, tgt) -> (batch, tgt_len, vocab) logits.
        validation_loader: iterable of (input, label) LongTensor batches.
        epoch, epochs: only used in the progress-bar description.

    Returns:
        Mean of the per-batch losses (float).

    Raises:
        ValueError: if validation_loader yields no batches.
    """
    v_loss_sum = 0.0
    v_batch_counter = 0

    model.eval()
    # Iterate the loader that was passed in. This previously read the global
    # `validation_data_loader` and silently ignored the argument.
    v_loop = tqdm(validation_loader)

    with torch.no_grad():
        for v_input, v_label in v_loop:
            v_input = v_input.to(device)
            v_label = v_label.to(device)

            v_output = model(v_input, v_label[:, :-1])

            # Flatten with the logits' own last dim instead of the global vocabulary_size.
            v_loss = criterion(v_output.reshape(-1, v_output.size(-1)), v_label[:, 1:].reshape(-1))
            v_loss_value = v_loss.item()
            v_loss_sum += v_loss_value
            v_batch_counter += 1

            v_loop.set_description(f"Epoch [{epoch}/{epochs}] Validation Loss: {v_loss_value:.10f}")
            v_loop.refresh()

    if v_batch_counter == 0:
        raise ValueError("validation_loader yielded no batches")

    return v_loss_sum / v_batch_counter

In [None]:
def _resolve_scheduler(lr_scheduler):
    """Map train()'s `lr_scheduler` argument to the scheduler object to step, or None."""
    import types

    if lr_scheduler is None:
        return None
    if isinstance(lr_scheduler, types.ModuleType):
        # Backward compatibility: the original call site passed the
        # torch.optim.lr_scheduler *module* while the loop stepped the global
        # `lr_scheduler_gen`. Keep that behaviour in this case.
        return lr_scheduler_gen
    return lr_scheduler


def _train_one_epoch(model, optimizer, data_loader, epoch, epochs):
    """Run one teacher-forced training pass and return the mean per-batch loss."""
    model.train()
    loop = tqdm(data_loader)
    loss_sum = 0.0
    batch_counter = 0

    for src, label in loop:
        src = src.to(device)
        label = label.to(device)

        # Decoder sees label[:, :-1] and is scored against the next tokens label[:, 1:].
        output = model(src, label[:, :-1])
        loss = criterion(output.reshape(-1, output.size(-1)), label[:, 1:].reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_sum += loss_value
        batch_counter += 1

        loop.set_description(f"Epoch [{epoch}/{epochs}] Loss: {loss_value:.10f}")
        loop.refresh()

    if batch_counter == 0:
        raise ValueError("train_data_loader yielded no batches")
    return loss_sum / batch_counter


def _plot_losses(train_losses, validation_losses):
    """Plot the per-epoch training (CE) and validation (VCE) losses."""
    plt.title("CE Losses During Training/Validation")
    plt.plot(train_losses, label="CE")
    plt.plot(validation_losses, label="VCE")
    plt.xlabel("Epoch")
    plt.ylabel("Losses")
    plt.legend()
    plt.show()


def train(model, optimizer, lr_scheduler, train_data_loader, validation_data_loader, epochs=5, retain_graph=True,
          checkpoint_path="/content/drive/MyDrive/data/models/generators/generator_P_SM_12_768_TRANS_E100_231024.pth"):
    """Train `model` with teacher forcing, validating and checkpointing every epoch.

    Each epoch does the following:
    - one pass over train_data_loader;
    - one LR-scheduler step;
    - a validation pass via `validate`;
    - a save of model.state_dict() to `checkpoint_path` when the validation
      loss improves.
    Per-epoch mean losses are appended to the notebook globals `loss_train_list`
    and `v_loss_train_list`. The globals `criterion` and `device` are also used.

    Args:
        model: model called as model(src, tgt) -> (batch, tgt_len, vocab) logits.
        optimizer: torch optimizer over the model's parameters.
        lr_scheduler: scheduler stepped once per epoch; None disables scheduling.
            If a module is passed (e.g. torch.optim.lr_scheduler), the global
            `lr_scheduler_gen` is stepped instead, matching the old behaviour.
        train_data_loader, validation_data_loader: iterables of (input, label) batches.
        epochs: number of epochs.
        retain_graph: unused; kept for backward compatibility.
        checkpoint_path: where the best model's state_dict is saved.

    Returns:
        The lowest mean validation loss seen.
    """
    scheduler = _resolve_scheduler(lr_scheduler)
    best_val_loss = float("inf")

    for epoch in range(epochs):
        clear_output()
        # Show the curves of all completed epochs while this epoch trains.
        if loss_train_list:
            _plot_losses(loss_train_list, v_loss_train_list)

        loss_mean = _train_one_epoch(model, optimizer, train_data_loader, epoch, epochs)

        if scheduler is not None:
            scheduler.step()

        v_loss_mean = validate(model, validation_data_loader, epoch, epochs)

        if v_loss_mean < best_val_loss:
            best_val_loss = v_loss_mean
            torch.save(model.state_dict(), checkpoint_path)

        loss_train_list.append(loss_mean)
        v_loss_train_list.append(v_loss_mean)

    # Final curves, including the last epoch.
    if loss_train_list:
        _plot_losses(loss_train_list, v_loss_train_list)

    return best_val_loss

In [None]:
# Pass the StepLR instance built in the Optimizers section. The bare name
# `lr_scheduler` refers to the torch.optim.lr_scheduler module imported at the top.
train(gen, opt_gen, lr_scheduler_gen, train_data_loader, validation_data_loader, epochs, retain_graph=False)