## Import Libraries

In [5]:
import torch
from model.ResNet import resnet_model
from model.CBAM import cbam_resnet_model

from data_loader import DataLoaderWrapper
from model.Hyperparameters import Hyperparameters as hp
import matplotlib.pyplot as plt

## Device configuration

In [6]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's RNGs (torch.manual_seed covers CPU and CUDA) before any model or
# loader is built below, so model initialisation is reproducible under
# Restart & Run All. Change SEED to study run-to-run variance.
SEED = 42
torch.manual_seed(SEED)

## Initialize data loaders

In [7]:
# Build the wrapper once, then expose its train / validation / test loaders as
# top-level names used by the training and evaluation cells.
data_loader = DataLoaderWrapper()
train_loader, val_loader, test_loader = (
    data_loader.train_loader,
    data_loader.val_loader,
    data_loader.test_loader,
)

Files already downloaded and verified
Number of training samples: 50000
Files already downloaded and verified
Number of test samples: 10000
training Data shape: torch.Size([3, 32, 32])
test Data shape: torch.Size([3, 32, 32])


## Initialize models

In [8]:
# One instance per architecture under comparison: the plain ResNet baseline and
# the CBAM variant (constructed in the same order as before).
resnet, cbam_resnet = resnet_model(), cbam_resnet_model()
# TODO: instantiate an SE-Net model here once one is available (section C has no training code yet).

## Train Models

### A. Training Baseline

In [9]:
# Baseline run: ResNet without attention. The returned metrics dict feeds the
# learning-curve plots further down (train/val losses and accuracies).
print("Training ResNet Model")
resnet_metrics = resnet.fit(
    train_loader,
    val_loader,
    device,
)

Training ResNet Model


### B. Training CBAM

In [None]:
# CBAM run: same loaders and device as the baseline so the two learning curves
# are directly comparable in the plotting cell below.
print("Training CBAM ResNet Model")
cbam_metrics = cbam_resnet.fit(
    train_loader,
    val_loader,
    device,
)

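### C. Training SE-Net

*TODO: SE-Net training is not implemented yet — no SE-Net model is instantiated or trained in this notebook.*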

## Evaluate models on test data


In [None]:
# Score each trained model on the held-out test loader; `.evaluate` results are
# unpacked as (loss, accuracy).
(resnet_test_loss, resnet_test_acc) = resnet.evaluate(
    test_loader, device
)
(cbam_test_loss, cbam_test_acc) = cbam_resnet.evaluate(
    test_loader, device
)


## Plot Training Results

In [None]:
# Compare learning curves of all trained models: losses on the left, accuracies
# on the right. Each `*_metrics` dict comes from `.fit` and is read per epoch under
# the keys 'train_losses', 'val_losses', 'train_accuracies' and 'val_accuracies'.
# To add a model (e.g. SE-Net), add one entry to `model_metrics`.
model_metrics = {
    'ResNet': resnet_metrics,
    'CBAM ResNet': cbam_metrics,
}

# One entry per subplot: (metric key suffix, legend label suffix, y-axis label, title).
panels = [
    ('losses', 'Loss', 'Loss', 'Training and Validation Losses'),
    ('accuracies', 'Acc', 'Accuracy', 'Training and Validation Accuracies'),
]

fig, axes = plt.subplots(1, 2, figsize=(15, 5))
for ax, (metric, short_label, y_label, title) in zip(axes, panels):
    for model_name, metrics in model_metrics.items():
        ax.plot(metrics[f'train_{metric}'], label=f'{model_name} Train {short_label}')
        ax.plot(metrics[f'val_{metric}'], label=f'{model_name} Val {short_label}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.set_title(title)

fig.tight_layout()
plt.show()

## Test Accuracy

In [None]:
# Report the final test-set accuracy of each model, one per line, to 4 decimals.
print(
    f'ResNet Test Accuracy: {resnet_test_acc:.4f}',
    f'CBAM ResNet Test Accuracy: {cbam_test_acc:.4f}',
    sep='\n',
)