In [1]:
from EEResnet18 import ResNet, ResidualBlock, data_loader
import numpy as np
import torch
import torch.nn as nn
import torch.backends
import torch.backends.mps
import torch.mps
from tqdm import tqdm

Setting RESNET-18 device as mps


In [2]:
train_loader, valid_loader = data_loader(data_dir='./data',
                                         batch_size=64, data_model="cifar100")

test_loader = data_loader(data_dir='./data',
                              batch_size=64,
                              test=True, data_model="cifar100")

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Setting TEST device as {device}")

device = torch.device(device)

Files already downloaded and verified
Files already downloaded and verified
50000
Files already downloaded and verified
Setting TEST device as mps


In [3]:
num_classes = 10
num_epochs = 20
# batch_size = 16
learning_rate = 0.01

model = ResNet().to(device)
model.requires_grad_ = True
# model = ResNet()
model.make_backbone(ResidualBlock, [2, 2, 2, 2], [1, 3], num_classes)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.001, momentum = 0.9)  
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay = 0.001)  

# Train the model
# total_step = len(train_loader)
print(len(train_loader) + len(valid_loader))

import gc
total_step = len(train_loader)
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1} ================")
    # n_batch = (n_sample - 1) // batch_size + 1
    # model.train()
    total_sample = 0
    total_loss = total_accuracy = 0
    with tqdm(total=64) as dadyNgoy:
        for i, (images, labels) in enumerate(train_loader): 
            # print(i) 
            # Move tensors to the configured device
            images = images.to(device)
            labels = labels.to(device)
            
            # Forward pass
            outputs = model(images)
            _, predicted = torch.max(outputs[-1].data, 1)
            loss = criterion(outputs[-1], labels)
            total_accuracy += (predicted == labels).sum().item()
            total_loss += loss
            total_sample += labels.size(0)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            del images, labels, outputs
            torch.mps.empty_cache() if torch.backends.mps.is_available() else torch.cuda.empty_cache()
            gc.collect()

            if (i + 1) % 3 == 0 or i + 1 == 64:
                dadyNgoy.set_postfix({
                            'avg_loss': total_loss / (i + 1),
                            'avg_accuracy': total_accuracy / (i + 1),
                            # 'max_abs_gradient': np.max(abs(model.weight.grad))
                        })
                cur_n_batch = i % 3 + 1
                dadyNgoy.update(cur_n_batch)

    print ('Epoch [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, loss.item()))
    accuracy = 100 * total_accuracy / total_sample
    print(f'Epoch {epoch+1}: Accuracy = {accuracy:.2f}%')
        
    model.eval()
                
        # Validation
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs[-1].data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            del images, labels, outputs
        
        print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total)) 

783


703it [01:05, 10.66it/s, avg_loss=tensor(3.6181, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=9.38]                       


Epoch [1/20], Loss: 3.6505
Epoch 1: Accuracy = 14.65%
Accuracy of the network on the 5000 validation images: 21.24 %


703it [01:02, 11.20it/s, avg_loss=tensor(2.9950, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=16.3]                       


Epoch [2/20], Loss: 3.7281
Epoch 2: Accuracy = 25.43%
Accuracy of the network on the 5000 validation images: 28.72 %


703it [01:02, 11.23it/s, avg_loss=tensor(2.6331, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=21]                         


Epoch [3/20], Loss: 1.9296
Epoch 3: Accuracy = 32.79%
Accuracy of the network on the 5000 validation images: 33.0 %


703it [01:03, 11.03it/s, avg_loss=tensor(2.3700, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=24.3]                       


Epoch [4/20], Loss: 2.6384
Epoch 4: Accuracy = 38.04%
Accuracy of the network on the 5000 validation images: 33.96 %


703it [01:02, 11.31it/s, avg_loss=tensor(2.1592, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=27.1]                       


Epoch [5/20], Loss: 2.0811
Epoch 5: Accuracy = 42.35%
Accuracy of the network on the 5000 validation images: 37.36 %


703it [01:02, 11.27it/s, avg_loss=tensor(1.9673, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=30.1]                       


Epoch [6/20], Loss: 1.3132
Epoch 6: Accuracy = 47.02%
Accuracy of the network on the 5000 validation images: 38.56 %


703it [01:02, 11.22it/s, avg_loss=tensor(1.7993, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=32.4]                       


Epoch [7/20], Loss: 0.7203
Epoch 7: Accuracy = 50.71%
Accuracy of the network on the 5000 validation images: 38.54 %


703it [01:02, 11.21it/s, avg_loss=tensor(1.6361, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=34.8]                       


Epoch [8/20], Loss: 1.6903
Epoch 8: Accuracy = 54.42%
Accuracy of the network on the 5000 validation images: 39.24 %


703it [01:02, 11.23it/s, avg_loss=tensor(1.4978, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=36.9]                       


Epoch [9/20], Loss: 1.3979
Epoch 9: Accuracy = 57.63%
Accuracy of the network on the 5000 validation images: 39.38 %


703it [01:03, 10.99it/s, avg_loss=tensor(1.3734, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=38.9]                       


Epoch [10/20], Loss: 2.1365
Epoch 10: Accuracy = 60.72%
Accuracy of the network on the 5000 validation images: 40.46 %


703it [01:02, 11.20it/s, avg_loss=tensor(1.2457, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=40.7]                       


Epoch [11/20], Loss: 1.8185
Epoch 11: Accuracy = 63.66%
Accuracy of the network on the 5000 validation images: 40.48 %


703it [01:02, 11.22it/s, avg_loss=tensor(1.1289, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=42.6]                       


Epoch [12/20], Loss: 0.9880
Epoch 12: Accuracy = 66.63%
Accuracy of the network on the 5000 validation images: 39.78 %


703it [01:02, 11.20it/s, avg_loss=tensor(1.0276, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=44.4]                       


Epoch [13/20], Loss: 0.9280
Epoch 13: Accuracy = 69.29%
Accuracy of the network on the 5000 validation images: 39.08 %


703it [01:02, 11.23it/s, avg_loss=tensor(0.9398, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=45.8]                       


Epoch [14/20], Loss: 2.5828
Epoch 14: Accuracy = 71.54%
Accuracy of the network on the 5000 validation images: 38.98 %


703it [01:03, 11.05it/s, avg_loss=tensor(0.8811, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=46.7]                       


Epoch [15/20], Loss: 1.1778
Epoch 15: Accuracy = 73.03%
Accuracy of the network on the 5000 validation images: 39.28 %


703it [01:03, 11.05it/s, avg_loss=tensor(0.7852, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=48.5]                       


Epoch [16/20], Loss: 1.0696
Epoch 16: Accuracy = 75.76%
Accuracy of the network on the 5000 validation images: 38.54 %


703it [01:03, 11.05it/s, avg_loss=tensor(0.7272, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=49.6]                       


Epoch [17/20], Loss: 1.3890
Epoch 17: Accuracy = 77.42%
Accuracy of the network on the 5000 validation images: 38.06 %


703it [01:03, 11.11it/s, avg_loss=tensor(0.7034, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=49.9]                       


Epoch [18/20], Loss: 1.5064
Epoch 18: Accuracy = 77.88%
Accuracy of the network on the 5000 validation images: 37.54 %


703it [01:03, 11.11it/s, avg_loss=tensor(0.6356, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=51.2]                       


Epoch [19/20], Loss: 1.2744
Epoch 19: Accuracy = 80.03%
Accuracy of the network on the 5000 validation images: 38.26 %


703it [01:03, 11.15it/s, avg_loss=tensor(0.5830, device='mps:0', grad_fn=<DivBackward0>), avg_accuracy=52.3]                       


Epoch [20/20], Loss: 0.1021
Epoch 20: Accuracy = 81.67%
Accuracy of the network on the 5000 validation images: 39.18 %


In [4]:
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs[-1].data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        del images, labels, outputs

    print('Accuracy of the network on the {} test images: {} %'.format(10000, 100 * correct / total))

Accuracy of the network on the 10000 test images: 72.57 %


In [5]:
torch.save(model.state_dict(), "model/RESNET_18-CIFAR_100-SGD-Early_Exits_in_1st_and_3rd_layer-MPS.pth")