In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
from dataloader.original import get_dataset
import torch
from torch.utils.data import DataLoader
from architecture.pipe import get_model
from tqdm import tqdm
import os

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
training_data = DataLoader(get_dataset('train', augmentation=True), batch_size=32, shuffle=True, num_workers=8, pin_memory=True)
testing_data = DataLoader(get_dataset('test', augmentation=True), batch_size=32, shuffle=True, num_workers=8, pin_memory=True)

In [5]:
vision_model = get_model()

In [6]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(vision_model.parameters(), lr=1e-4)

In [7]:
state = torch.load('./saves/exp_5/model_epoch_2.pth')
vision_model.load_state_dict(state['model_state_dict'])
optimizer.load_state_dict(state['optimizer_state_dict'])
del state

In [8]:
def train_loop(dataloader, testing_data, model, loss_fn, optimizer, epochs=10, save_path='./saves'):
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    all_exps = [int(elm.replace('exp_', '')) if elm != '.ipynb_checkpoints' else -1 for elm in os.listdir(save_path)]

    current_num = max(all_exps) if (len(all_exps) > 0) else 0
    save_path = save_path + '/exp_' + str(current_num + 1)
    os.mkdir(save_path)

    model.to(device)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        model.train()
        loop = tqdm(dataloader, total=len(dataloader), leave=True)
        total_loss = 0
        accuracy = 0

        for batch_idx, (data, targets) in enumerate(loop):
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            scores = model(data)
            loss = loss_fn(scores, targets)
            
            batch_accuracy = ((scores.argmax(-1) == targets.argmax(-1)).int().sum() / len(data)).item()
            accuracy += batch_accuracy

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Update progress bar
            loop.set_postfix(loss=loss.item(), accuracy=batch_accuracy)
        
        average_training_accuracy = accuracy / len(dataloader.dataset)

        model.eval()
        testing_loss = 0
        testing_accuracy = 0
        for test in tqdm(testing_data):
            X, y = test
            
            X, y = X.to(device), y.to(device)
            
            with torch.no_grad():                
                score = model(X)
                testing_loss =+ loss_fn(score, y)

            testing_accuracy += ((score.argmax(-1) == y.argmax(-1)).int().sum() / len(X)).item()
            
        average_testing_accuracy = accuracy / len(testing_data)

        model.train()

        print(f"Epoch {epoch + 1} average loss: {total_loss / len(dataloader)} with testing loss of {testing_loss / len(testing_data)}, with traning accuracy {average_testing_accuracy}, and testing accuracy {average_testing_accuracy}")

        # Save the models after each epoch
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss / len(dataloader),
            'testing_loss': (testing_loss / len(testing_data)).cpu(),
            'average_training_accuracy': average_training_accuracy,
            'average_testing_accuracy': average_testing_accuracy,
        }, f"{save_path}/model_epoch_{epoch + 1}.pth")

    print("Training complete!")

In [9]:
train_loop(training_data, testing_data, vision_model, loss_fn, optimizer, epochs=10)

Epoch 1/10


100%|██████████| 856/856 [03:54<00:00,  3.65it/s, accuracy=0, loss=3.21]     
100%|██████████| 163/163 [01:00<00:00,  2.67it/s]


Epoch 1 average loss: 3.41473095578568 with testing loss of 0.02029408887028694, with traning accuracy 0.07440633594167013, and testing accuracy 0.07440633594167013
Epoch 2/10


100%|██████████| 856/856 [05:22<00:00,  2.66it/s, accuracy=0, loss=3.42]     
100%|██████████| 163/163 [00:59<00:00,  2.75it/s]


Epoch 2 average loss: 3.4101071446855493 with testing loss of 0.02189774625003338, with traning accuracy 0.0718941717791411, and testing accuracy 0.0718941717791411
Epoch 3/10


100%|██████████| 856/856 [05:04<00:00,  2.81it/s, accuracy=0.125, loss=3.48] 
100%|██████████| 163/163 [00:58<00:00,  2.81it/s]


Epoch 3 average loss: 3.412141843376873 with testing loss of 0.020703157410025597, with traning accuracy 0.06367013962265172, and testing accuracy 0.06367013962265172
Epoch 4/10


100%|██████████| 856/856 [04:50<00:00,  2.95it/s, accuracy=0, loss=3.47]     
100%|██████████| 163/163 [00:47<00:00,  3.42it/s]


Epoch 4 average loss: 3.404780360861359 with testing loss of 0.020638303831219673, with traning accuracy 0.07116696636186787, and testing accuracy 0.07116696636186787
Epoch 5/10


100%|██████████| 856/856 [03:32<00:00,  4.03it/s, accuracy=0.125, loss=3.45] 
100%|██████████| 163/163 [00:43<00:00,  3.72it/s]


Epoch 5 average loss: 3.4072035982787052 with testing loss of 0.020553115755319595, with traning accuracy 0.06045060289418039, and testing accuracy 0.06045060289418039
Epoch 6/10


  0%|          | 0/856 [00:20<?, ?it/s]


KeyboardInterrupt: 