In [1]:
import torch
from operator import itemgetter
from torch.utils.data import DataLoader
import random
import numpy as np
import math
import os
from update_utilities import update_utilities_class

# 1 - Data Preparation

## 1.1. Examining the Data

In [2]:
# Read the whole corpus file into memory as a single string.
with open('lord-of-the-rings-processed.txt', mode='r', encoding='utf-8') as book_file:
    text = book_file.read()

In [3]:
print(f"length of the book - {len(text)} characters")

length of the book - 3729059 characters


In [4]:
print(text[:100])

The Music of the Ainur There was Eru, the One, who in Arda is called lluvatar; and he made first the


## 1.2. Format Data

In [74]:
# Every distinct character in the corpus, in sorted order.
chars = sorted(set(text))
print(chars)

['\n', ' ', '!', '"', "'", '(', ')', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '®', '—', '‘', '’', '“', '”']


In [75]:
# Digits, ASCII letters and basic punctuation that need no special handling.
common = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ();:.!?-,"
# Whatever is left over is worth a closer look before tokenising.
special = [char for char in chars if char not in common]
print(special)

['\n', ' ', '"', "'", '®', '—', '‘', '’', '“', '”']


In [76]:
# Turn line breaks into spaces so the book becomes one continuous line.
text = text.replace("\n"," ")
# Collapse runs of spaces down to a single space. One replace() pass is not
# enough: three spaces become two, so repeat until no double space remains.
while "  " in text:
    text = text.replace("  ", " ")
# NOTE(review): '®' is presumably an extraction artefact standing in for a 'u' — confirm against the source text.
text = text.replace("®", "u")

In [84]:
# Curly quotes plus the punctuation whose surrounding spaces get fixed below.
# The characters are listed explicitly instead of being picked out of `special`
# by position: those positions only exist for the raw text, and once the file
# has been overwritten with the cleaned text `special` is too short for them.
special_char = ['‘', '’', '“', '”']
special_char.extend([",",";",":","!","?"])
special_char

['‘', '’', '“', '”', ',', ';', ':', '!', '?']

In [88]:
# Positions 0 and 2 of special_char hold the two opening quotation marks.
no_space_after = [special_char[position] for position in (0, 2)]
no_space_after

['‘', '“']

In [87]:
# Everything in special_char except the opening quotes (positions 0 and 2).
# enumerate() keeps the original order; the previous set difference relied on
# the unspecified iteration order of a set.
no_space_before = [char for position, char in enumerate(special_char) if position not in (0, 2)]
no_space_before

['’', '”', ',', ';', ':', '!', '?']

In [89]:
# Drop the space that follows an opening mark: <‘ sss> becomes <‘sss>.
for opener in no_space_after:
    text = text.replace(opener + " ", opener)

# Drop the space that precedes a closing mark or punctuation: <s ,> becomes <s,>.
for closer in no_space_before:
    text = text.replace(" " + closer, closer)


In [90]:
# Map every quotation-mark variant onto the plain apostrophe.
for quote_mark in ('"', '‘', '’', '“', '”'):
    text = text.replace(quote_mark, "'")

In [91]:
# Persist the cleaned text. The file is read back with encoding='utf-8', so it
# must be written as UTF-8 too; the platform default encoding may differ and
# the text still contains non-ASCII characters such as '—'.
with open("lord-of-the-rings-processed.txt","w",encoding="utf-8") as f:
    f.write(text)

## 1.3. Create Dictionary and Tokenize the Data

**tokenizer**

In [5]:
# Vocabulary of the cleaned text: every distinct character, sorted.
chars = sorted(set(text))
common = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ();:.!?-,"
# Characters outside the "ordinary" set that survived cleaning.
special = [char for char in chars if char not in common]
print(special)

[' ', "'", '—']


In [6]:
# id -> character, using the position of each character in the sorted vocabulary.
decode_char = dict(enumerate(chars))
# character -> id, the inverse mapping.
encode_char = {char: index for index, char in decode_char.items()}
vocab_size = len(encode_char)
print(vocab_size)

74


In [7]:
def encode(string):
    """Turn a string into the list of integer ids of its characters."""
    return [encode_char[s] for s in string]


def decode(nums):
    """Turn an iterable of integer ids back into the string they spell."""
    return ''.join(decode_char[n] for n in nums)

In [8]:
encode("This is good")

[40, 54, 55, 65, 0, 55, 65, 0, 53, 61, 61, 50]

In [9]:
decode([8,20,69,44,27])

'0?wXG'

## 1.4. Load data and construct batches + dataloaders

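**Hold out part of each book as validation data (`ratio=0.85` is passed to the split helper for every book, giving roughly 85% training / 15% validation text overall)**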

In [10]:
# update_utilities_class(file_name="general_functions.py",current_path=os.getcwd()).run()


In [10]:
from general_functions import HelperFunctionsClass
h = HelperFunctionsClass()

In [11]:
# Book 1: starts at index 0 and is delimited by its closing sentence.
book1_train, book1_val, end_idx = h.train_test_split(
    text=text,
    ending="and an end was come for the Eldar of story and of song.",
    ratio=0.85,
    starting_idx=0,
)

In [12]:
# Book 2: continues from the index returned by the previous split.
book2_train, book2_val, end_idx = h.train_test_split(
    text=text,
    ending="and handed him the tobacco-jar.",
    ratio=0.85,
    starting_idx=end_idx,
)

In [13]:
# Book 3: continues from the index returned by the previous split.
book3_train, book3_val, end_idx = h.train_test_split(
    text=text,
    ending="THE RETURN OF THE KING.",
    ratio=0.85,
    starting_idx=end_idx,
)

In [14]:
# Book 4: continues from the index returned by the previous split.
book4_train, book4_val, end_idx = h.train_test_split(
    text=text,
    ending="was alive but taken by the Enemy.",
    ratio=0.85,
    starting_idx=end_idx,
)

In [15]:
# Book 5: last split; its returned index is kept separately in end_idx2.
book5_train, book5_val, end_idx2 = h.train_test_split(
    text=text,
    ending="I'm back,' he said.",
    ratio=0.85,
    starting_idx=end_idx,
)

In [16]:
# Concatenate the per-book pieces so that each split contains material from all five books.
train_data = book1_train + book2_train + book3_train + book4_train + book5_train
val_data = book1_val + book2_val + book3_val + book4_val + book5_val

In [17]:
len(train_data + val_data)

3729058

In [18]:
# Run both splits through the character-level tokenizer and wrap the resulting ids in tensors.
train_data2 = torch.tensor(encode(train_data))
val_data2 = torch.tensor(encode(val_data))

In [19]:
len(train_data2), len(val_data2)

(3170090, 558968)

**dataset and dataloader**

In [20]:
# update_utilities_class(file_name="custom_text_dataset.py",current_path=os.getcwd()).run()

In [21]:
from custom_text_dataset import slideTokenizedTextDataset

In [22]:
# Context length, in characters, of every training example.
block_size = 512

# Wrap the encoded tensors in the project's dataset class.
# NOTE(review): presumably each item is an input window plus the same window
# shifted by one character — confirm in custom_text_dataset.py.
train_dataset = slideTokenizedTextDataset(full_txt=train_data2, block_size=block_size)
val_dataset = slideTokenizedTextDataset(full_txt=val_data2, block_size=block_size)

In [23]:
len(train_dataset), len(val_dataset)

(3169578, 558456)

In [24]:
batch_size = 64
# Shuffled training batches; the final incomplete batch is dropped.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [25]:
# Validate on a random subset of val_num_samples windows drawn without replacement.
val_num_samples = 200000
val_sampler = torch.utils.data.RandomSampler(
    val_dataset,
    replacement=False,
    num_samples=val_num_samples,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    sampler=val_sampler,
    drop_last=True,
)

In [26]:
# Pull one batch and show the third example's input followed by its target.
s_input, s_output = next(iter(train_dataloader))
for sample in (s_input, s_output):
    print(decode(sample[2].tolist()))

We shall all be caught and killed! I thought you said he was not a friend of theirs.' 'So I did. And don't be silly! You had better go to bed, your wits are sleepy.' The hobbit felt quite crushed, and as there seemed nothing else to do he did go to bed; and while the dwarves were still singing songs he dropped asleep, still puzzling his little head about Beorn, till he dreamed a dream of hundreds of black bears dancing slow heavy dances round and round in the moonlight in the courtyard. Then he woke up when
e shall all be caught and killed! I thought you said he was not a friend of theirs.' 'So I did. And don't be silly! You had better go to bed, your wits are sleepy.' The hobbit felt quite crushed, and as there seemed nothing else to do he did go to bed; and while the dwarves were still singing songs he dropped asleep, still puzzling his little head about Beorn, till he dreamed a dream of hundreds of black bears dancing slow heavy dances round and round in the moonlight in the courtya

In [27]:
len(train_dataloader), len(val_dataloader)

(49524, 3125)

# 2 - Model definition

In [28]:
# update_utilities_class(file_name="Transformer.py",current_path=os.getcwd()).run()

In [29]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [30]:
import Transformer
# Build the model with the hyper-parameters below and move it to the selected device.
transformer = Transformer.TransformerClass(
    vocab_size=vocab_size,
    emb_dim=512,
    n_layer=6,
    num_heads=8,
    block_size=block_size,
    dropout_rate_attention=0.1,
    dropout_rate_ff=0.1,
    dropout_rate_pos_enc=0.1,
    is_decoder=True,
    ff_multiplier=4,
).to(device)


In [31]:
# Total number of parameters, reported in millions.
print(round(sum(map(torch.numel, transformer.parameters())) / 1e6, 2), 'M parameters')

18.98 M parameters


'cuda'

In [45]:
# update_utilities_class(file_name="loss_functions.py",current_path=os.getcwd()).run()

File copied, now the file is available to import from the destinated path


In [37]:
from tqdm.auto import tqdm
import math
import torch
from update_utilities import update_utilities_class
import os
import numpy as np
import time
import copy

class train_test_loop_class:
    """Training / validation / testing loop for a PyTorch model with on-disk logging.

    The constructor creates a ``"<model_name> stats"`` folder, restores any
    loss/accuracy history saved there by a previous run, and resolves the loss
    function for ``problem_type`` through ``loss_functions_class``.
    ``train()`` runs the optimisation loop (or a learning-rate sweep when
    ``lr_rate_tuning`` is set); ``test()`` evaluates on the validation or the
    test loader.
    """

    # File stems (and dictionary keys) of the persisted history arrays.
    _LOSS_KEYS = ("train_loss_all", "train_loss", "validation_loss")
    _ACC_KEYS = ("train_acc_all", "train_acc", "validation_acc")

    def __init__(self, model:torch.nn.Module, 
                 train_loader:torch.utils.data.DataLoader,
                 val_loader:torch.utils.data.DataLoader,
                 test_loader, epochs, print_every_n_batch,
                 device,model_name, optimizer,calculate_accuracy,problem_type,overwrite_message,update_loss_fn=False,
                 print_result = True, print_full = True, lr_rate_tuning=False, 
                 clip_batch=False,clip_batch_size=32, lr_start = -5, lr_end = -2):
        """Store the configuration, restore past history and look up the loss function.

        Args:
            model: network to train.
            train_loader, val_loader, test_loader: iterables of ``(inputs, labels)`` batches.
            epochs: number of passes over ``train_loader``.
            print_every_n_batch: run validation and log progress every this many steps.
            device: device the batches are moved to.
            model_name: prefix used for every folder and file written by this class.
            optimizer: optimizer already bound to ``model``'s parameters.
            calculate_accuracy: also track accuracy next to the loss.
            problem_type: key handed to ``loss_functions_class.get_loss_fn``; a value
                containing ``"Binary"`` selects the squeeze/sigmoid handling.
            overwrite_message: overwrite (True) or append to (False) the training log file.
            update_loss_fn: refresh ``loss_functions.py`` through ``update_utilities_class`` first.
            print_result, print_full: console verbosity switches.
            lr_rate_tuning: make ``train()`` run ``lr_tuning`` instead of training.
            clip_batch, clip_batch_size: optionally truncate batches during the sweep.
            lr_start, lr_end: base-10 exponents bounding the learning-rate sweep.

        Raises:
            ValueError: if no loss function can be resolved for ``problem_type``.
        """
        # initialize variables
        self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader, test_loader
        self.model = model
        self.epochs, self.device = epochs, device
        self.model_name, self.optimizer = model_name, optimizer
        self.calculate_accuracy = calculate_accuracy
        self.print_progress = print_every_n_batch
        self.problem_type = problem_type
        self.overwrite_message = overwrite_message
        self.print_result, self.print_full = print_result, print_full
        self.lr_rate_tuning = lr_rate_tuning
        self.clip_batch = clip_batch
        self.clip_batch_size = clip_batch_size
        self.lr_start, self.lr_end = lr_start, lr_end

        # create folder to hold model stats
        self.model_folder = f"{self.model_name} stats"
        os.makedirs(self.model_folder, exist_ok=True)

        # initialize or read past losses (best effort: a missing or unreadable
        # history simply means starting from empty lists)
        past_loss_folder = os.path.join(self.model_folder, f"{self.model_name} losses")
        self.accuracy = {key: [] for key in self._ACC_KEYS}
        try:
            self.losses = self._load_history(past_loss_folder, self._LOSS_KEYS)
        except (OSError, ValueError):
            self.losses = {key: [] for key in self._LOSS_KEYS}
        else:
            if self.calculate_accuracy:
                try:
                    self.accuracy = self._load_history(past_loss_folder, self._ACC_KEYS)
                except (OSError, ValueError):
                    # earlier runs may not have tracked accuracy; keep the empty lists
                    pass

        # update and import loss_functions class
        if update_loss_fn:
            if self._verbose: print("update and import the loss_functions module\n")
            update_file = update_utilities_class(file_name="loss_functions.py",current_path=os.getcwd())
            update_file.run()

        from loss_functions import loss_functions_class
        loss_function = loss_functions_class()
        try:
            self.loss_fn = loss_function.get_loss_fn(self.problem_type)
        except Exception as err:
            # Fail with the explanatory message instead of continuing without a loss function.
            raise ValueError(
                f"{self.problem_type} Problem Type is not predefined in the loss_functions_class, need to be added manually"
            ) from err

        if self._verbose: print("\nAll initialized, ready to go!")

    @property
    def _verbose(self):
        """True when both ``print_result`` and ``print_full`` are enabled."""
        return bool(self.print_result and self.print_full)

    @staticmethod
    def _load_history(folder, keys):
        """Load one ``<key>.npy`` array per key from ``folder`` and return them as lists."""
        return {key: list(np.load(os.path.join(folder, key + ".npy"))) for key in keys}

    def _log(self, message, log_file, suffix=""):
        """Print ``message`` in verbose mode and always append it to ``log_file`` on a new line."""
        if self._verbose: print(message)
        log_file.write("\n" + message + suffix)

    def _compute_loss(self, model_outputs, batch_labels):
        """Return ``(loss, outputs)``; ``outputs`` is squeezed for binary problems.

        Outputs that are not 2-D have their first two dimensions flattened
        (labels likewise) before the loss function is applied.
        """
        if "Binary" in self.problem_type:
            model_outputs = model_outputs.squeeze()
            return self.loss_fn(model_outputs, batch_labels.float()), model_outputs
        if model_outputs.dim() == 2:
            return self.loss_fn(model_outputs, batch_labels), model_outputs
        loss = self.loss_fn(torch.flatten(model_outputs, end_dim=1), torch.flatten(batch_labels, end_dim=1))
        return loss, model_outputs

    def _batch_accuracy(self, model_outputs, batch_labels):
        """Fraction of correctly predicted labels in one batch.

        The class dimension is taken to be the last one and the count is divided
        by the total number of labels, so sequence outputs of shape
        (batch, seq, classes) are scored per token; for (batch, classes) outputs
        this equals ``argmax(dim=1)`` divided by the batch size.
        """
        if "Binary" in self.problem_type:
            pred_labels = torch.round(torch.sigmoid(model_outputs))
        else:
            pred_labels = model_outputs.argmax(dim=-1)
        return torch.eq(pred_labels, batch_labels).sum().item() / batch_labels.numel()

    def lr_tuning(self,dataloader,optimizer,start,end,clip_batch, batch_size):
        """Sweep the learning rate from ``10**start`` to ``10**end`` and record the loss.

        One optimisation step is taken per batch with a learning rate that grows
        log-uniformly (40 steps per decade). ``optimizer`` must be bound to
        ``self.model``: the model is trained in place so the steps really take
        effect, and both model and optimizer are restored to their initial state
        afterwards.

        Args:
            dataloader: iterable of ``(inputs, labels)`` batches.
            optimizer: optimizer whose learning rate is overridden each step.
            start, end: base-10 exponents of the first and last learning rate.
            clip_batch: truncate every batch to ``batch_size`` samples when True.
            batch_size: maximum number of samples kept when ``clip_batch`` is set.

        Returns:
            List with one loss value per step taken; shorter than the planned
            sweep if ``dataloader`` runs out of batches first.
        """
        assert start-end < 0, "start and end should be negative where start less than end"
        if self._verbose: print("learning rate tuning\n")
        num_step = int((end - start) * 40) + 1
        lrs = (10 ** torch.linspace(start, end, num_step)).tolist()
        lossi = []

        # Snapshot everything the sweep modifies so it can be undone afterwards.
        model = self.model
        was_training = model.training
        model_state = copy.deepcopy(model.state_dict())
        optimizer_state = copy.deepcopy(optimizer.state_dict())
        model.train()
        try:
            for lr, (batch_inputs, batch_labels) in zip(lrs, dataloader):
                if clip_batch:
                    batch_inputs, batch_labels = batch_inputs[:batch_size], batch_labels[:batch_size]
                # define the learning rate
                for g in optimizer.param_groups:
                    g['lr'] = lr
                # regular forward pass and backpropagation
                batch_inputs, batch_labels = batch_inputs.to(self.device), batch_labels.to(self.device)
                optimizer.zero_grad()
                loss, _ = self._compute_loss(model(batch_inputs), batch_labels)
                loss.backward()
                optimizer.step()
                lossi.append(loss.item())
        finally:
            optimizer.zero_grad()
            model.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)
            model.train(was_training)

        if self._verbose: print("learning rate tuning finished\n")
        return lossi

    def test(self,mode):
        """Evaluate the model on the validation loader (``mode == "validation"``) or the test loader.

        For any other mode the results are also printed and appended to the
        training information file.

        Returns:
            ``(avg_loss, avg_acc)``: the mean per-batch loss and the mean
            per-batch accuracy in percent (0 when accuracy is not tracked).
        """
        self.model.eval()
        dataloader = self.val_loader if mode == "validation" else self.test_loader
        batch_loss = 0
        batch_acc = 0

        with torch.inference_mode():
            for batch_inputs, batch_labels in dataloader:
                batch_inputs, batch_labels = batch_inputs.to(self.device), batch_labels.to(self.device)
                loss, model_outputs = self._compute_loss(self.model(batch_inputs), batch_labels)
                batch_loss += loss
                if self.calculate_accuracy:
                    batch_acc += self._batch_accuracy(model_outputs, batch_labels)
            avg_loss = batch_loss / len(dataloader)
            avg_acc = batch_acc / len(dataloader) * 100

        if mode != "validation":
            message_file_path = os.path.join(self.model_folder, f"{self.model_name} - Training Information.txt")
            with open(message_file_path,"a") as f:
                f.write("\n\nTesting Information\n"+"-"*80)
                m = f"Average per-Batch Test Loss: {avg_loss:.4f}"
                print(m)
                f.write("\n"+m)
                if self.calculate_accuracy:
                    m = f"Average per-Batch Test Accuracy: {avg_acc:.2f}%"
                    print(m)
                    f.write("\n"+m)
        return avg_loss, avg_acc

    def train(self):
        """Run the training loop, or the learning-rate sweep when ``lr_rate_tuning`` is set.

        Progress is written to ``"<model_name> - Training Information.txt"``; the
        log file is closed even if training is interrupted.

        Returns:
            The list of losses from ``lr_tuning`` when ``lr_rate_tuning`` is
            set, otherwise None.
        """
        try:
            lowest_val_loss = np.load(os.path.join(self.model_folder,f"{self.model_name}_lowest_val_loss.npy")).item()
        except (OSError, ValueError):
            # no usable record from a previous run
            lowest_val_loss = 1000000000

        # write the output to a file
        message_file_path = os.path.join(self.model_folder, f"{self.model_name} - Training Information.txt")
        with open(message_file_path, "w" if self.overwrite_message else "a") as f:
            return self._train(f, lowest_val_loss)

    def _train_step(self, batch_inputs, batch_labels):
        """Run one optimisation step; return ``(loss value, accuracy or None)``."""
        batch_inputs, batch_labels = batch_inputs.to(self.device), batch_labels.to(self.device)
        # validation switches the model to eval mode, so re-enable training mode every step
        self.model.train()
        self.optimizer.zero_grad()
        loss, model_outputs = self._compute_loss(self.model(batch_inputs), batch_labels)
        loss.backward()
        self.optimizer.step()
        acc = self._batch_accuracy(model_outputs, batch_labels) if self.calculate_accuracy else None
        # .item() yields a plain float, so no autograd history is kept alive by the caller
        return loss.item(), acc

    def _validate_and_log(self, f, step, num_steps, avg_batch_loss, avg_batch_acc, lowest_val_loss, weights_folder):
        """Validate, record and log one progress cycle; return the updated lowest validation loss."""
        validation_loss, validation_acc = self.test(mode="validation")
        validation_loss = float(validation_loss)
        self.losses["validation_loss"].append(validation_loss)
        self.losses["train_loss"].append(avg_batch_loss)

        # print message
        m = f"Average per-Batch Training Loss: {avg_batch_loss:.4f} || " \
              + f"Average per-Batch Validation Loss: {validation_loss:.4f}"
        if self.print_result and not self.print_full:
            print(f"Batch: {step} / {num_steps} || " + m)
        self._log(m, f)

        # if accuracy need be to calculated
        if self.calculate_accuracy:
            self.accuracy["validation_acc"].append(validation_acc)
            self.accuracy["train_acc"].append(avg_batch_acc)
            m = f"Average per-Batch Training Accuracy: {avg_batch_acc:.2f}% || " \
                  + f"Average per-Batch Validation Accuracy: {validation_acc:.2f}%"
            if self.print_result:
                print(m)
                if not self.print_full:
                    print()
            f.write("\n"+m)

        # calculate model improvement relative to the previous progress cycle
        if len(self.losses["train_loss"]) > 1:
            idx = len(self.losses["train_loss"]) - 1
            train_loss_perc_decrease = -(self.losses["train_loss"][idx]-self.losses["train_loss"][idx-1]) \
                                        / self.losses["train_loss"][idx-1] * 100
            val_loss_perc_decrease = -(self.losses["validation_loss"][idx]-self.losses["validation_loss"][idx-1]) \
                                     / self.losses["validation_loss"][idx-1] * 100
            self._log("\nModel Improvement\n--------------------------------", f)
            self._log(f"Average per-Batch Training Loss has decreased by {train_loss_perc_decrease:.2f}%", f)
            self._log(f"Average per-Batch Validation Loss has decreased by {val_loss_perc_decrease:.2f}%\n", f)

        # if validation loss is the lowest, save the model as the best model weights
        # (checked on every cycle, including the very first one)
        if validation_loss < lowest_val_loss:
            save_path = os.path.join(weights_folder, self.model_name+"_best.pth")
            m = f"Val Loss decreased from {lowest_val_loss:.4f} to {validation_loss:.4f} - Saving the Best Model\n"
            self._log(m, f, suffix="\n")
            torch.save(self.model.state_dict(),save_path)
            lowest_val_loss = validation_loss
            np.save(os.path.join(self.model_folder,f"{self.model_name}_lowest_val_loss.npy"),lowest_val_loss)
        return lowest_val_loss

    def _save_history(self):
        """Write every loss (and, if tracked, accuracy) series to ``<model_name> losses``."""
        loss_folder_path = os.path.join(self.model_folder, self.model_name + " losses")
        os.makedirs(loss_folder_path, exist_ok=True)
        for key, values in self.losses.items():
            np.save(os.path.join(loss_folder_path,key+".npy"),values)
        if self.calculate_accuracy:
            for key, values in self.accuracy.items():
                np.save(os.path.join(loss_folder_path,key+".npy"),values)

    def _train(self, f, lowest_val_loss):
        """Body of ``train()``; ``f`` is the already opened training information file."""
        start = time.time()
        total_time = 0

        if self.overwrite_message:
            self._log(f"Basic Specs\n----------------------------------------------------", f)
            sample_inputs, _ = next(iter(self.train_loader))
            self._log(f"Input Size: {sample_inputs.shape}\n", f)
            self._log("\nModel Specs: \n", f)
            if self._verbose: print(self.model)
            print(self.model,file=f)
            self._log("\n\n", f)

        f.write("\n\nTraining Information\n" + "-"*80)

        # initializing
        num_steps = self.epochs * len(self.train_loader)
        print_progress_cycle = 0 # this keeps track of the current number of print_progress cycle
        total_print_progress_cycle = math.ceil(num_steps/self.print_progress)

        # print initial message
        self._log(f"Training Begin\n----------------------------------------------------", f)
        self._log(f"There are {self.epochs} epochs, and for each epoch, there are {len(self.train_loader)} batches of training data", f)
        self._log(f"Total Training Steps: {num_steps}", f)
        self._log(f"Total Displaying Information: {total_print_progress_cycle}", f)
        self._log(f"Optimizer name - {self.optimizer.__class__.__name__} learning rate: {self.optimizer.param_groups[-1]['lr']}", f)
        self._log(f"lowest_val_loss started with {lowest_val_loss}\n", f)

        # create directory for the model weights
        folder_path = os.path.join(self.model_folder, self.model_name + " weights")
        os.makedirs(folder_path, exist_ok=True)

        if self.lr_rate_tuning:
            return self.lr_tuning(self.train_loader,self.optimizer,self.lr_start,self.lr_end,self.clip_batch,self.clip_batch_size)

        # running totals since the last progress report
        step = 0
        batch_loss = 0.0
        batch_acc = 0

        progress_bar = tqdm(range(num_steps))
        try:
            # training loop
            for e in range(self.epochs):
                for batch_inputs, batch_labels in self.train_loader:
                    step += 1
                    loss_value, acc = self._train_step(batch_inputs, batch_labels)

                    # append loss
                    self.losses["train_loss_all"].append(loss_value)
                    batch_loss += loss_value
                    if acc is not None:
                        self.accuracy["train_acc_all"].append(acc)
                        batch_acc += acc

                    # validate and print progress if reach the print_progress or at the last step
                    if step % self.print_progress == 0 or step == num_steps:
                        print_progress_cycle += 1
                        # the final cycle may cover fewer batches than a full one
                        if step != num_steps or num_steps % self.print_progress == 0:
                            batch_count = self.print_progress
                        else:
                            batch_count = num_steps % self.print_progress

                        self._log(f"\n\nMessage: {print_progress_cycle} - "
                                  + f"Progress Summary - {batch_count} batches\n--------------------------------", f)
                        self._log(f"Epoch: {e+1} / {self.epochs} || Batch: {step} / {num_steps} || "
                                  + f"Print Cycle: {print_progress_cycle} / {total_print_progress_cycle}", f)

                        lowest_val_loss = self._validate_and_log(
                            f, step, num_steps,
                            batch_loss / batch_count,
                            batch_acc / batch_count * 100,
                            lowest_val_loss, folder_path)
                        batch_loss, batch_acc = 0.0, 0

                        time_spent = np.round((time.time()-start)/60,2)
                        total_time += time_spent
                        unit = "minutes"
                        if time_spent > 60:
                            time_spent = np.round(time_spent/60,2)
                            unit = "hours"
                        self._log(f"This printing cycle took {time_spent} {unit}\n", f)
                        start = time.time()

                    # update progress bar
                    progress_bar.update(1)
        finally:
            progress_bar.close()

        # outside epoch for loop
        save_path = os.path.join(folder_path, self.model_name+"_last.pth")
        self._log("Saving the Last Model\n", f)
        torch.save(self.model.state_dict(),save_path)

        # save losses/accuracies
        self._save_history()

        print("\n All Done\n")
        time_spent = np.round(total_time/60,2)
        m = f"Overall training took {time_spent} hours\n"
        print(m)
        f.write("\n\n"+m+"-"*80+"\n\n\n\n")

    # outside train function



In [38]:
optimizer = torch.optim.AdamW(transformer.parameters(), lr=1e-3)

# One epoch, validating and logging every 400 batches; compact console output.
train_loop = train_test_loop_class(
    model=transformer,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    test_loader=None,
    epochs=1,
    print_every_n_batch=400,
    device=device,
    model_name="test",
    optimizer=optimizer,
    calculate_accuracy=False,
    overwrite_message=True,
    problem_type="Multiclass Classification",
    update_loss_fn=False,
    print_result=True,
    print_full=False,
    lr_rate_tuning=False,
    clip_batch=False,
    clip_batch_size=20,
    lr_start=-5,
    lr_end=-2,
)

In [39]:
train_loop.train()

  0%|          | 0/49524 [00:00<?, ?it/s]

Batch: 400 / 49524 || Average per-Batch Training Loss: 2.4257 || Average per-Batch Validation Loss: 2.2034
Batch: 800 / 49524 || Average per-Batch Training Loss: 1.7675 || Average per-Batch Validation Loss: 1.3902


KeyboardInterrupt: 