In [9]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from sklearn.metrics import precision_recall_fscore_support
import os
import time
from torch.optim.lr_scheduler import StepLR

In [10]:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [11]:
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


In [12]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

class ADLNet(nn.Module):
    def __init__(self, num_classes=100):
        super(ADLNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25)
        )
        self.avgpool = nn.AdaptiveAvgPool2d((3, 3))
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 3 * 3, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), 128 * 3 * 3)
        x = self.classifier(x)
        return x
model = ADLNet().to(device)
# print(model.summary())
print(model)

Using cpu device
ADLNet(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Dropout(p=0.25, inplace=False)
    (8): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0

In [13]:
lr_step_gamma = 0.7
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=1, gamma=lr_step_gamma)

In [14]:
def train(dataloader, model, loss_fn, optimizer, scheduler, prune_percentage):
    size = len(dataloader.dataset)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


        # Get the magnitudes of weights and sort them in descending order
        weight_magnitudes = []
        for param in model.parameters():
            weight_magnitudes.extend(torch.abs(param.data).flatten())
        weight_magnitudes = torch.sort(torch.tensor(weight_magnitudes), descending=True).values

        # Set pruning threshold based on specified percentage
        prune_threshold = weight_magnitudes[int(prune_percentage * len(weight_magnitudes))]

        # Prune weights based on the threshold
        for param in model.parameters():
            mask = torch.abs(param.data) >= prune_threshold
            param.data *= mask

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

        # Update the learning rate
        scheduler.step()

In [15]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [16]:
# Fine-tuning after pruning
fine_tune_optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # Use lower learning rate
fine_tune_epochs = 2
prune_percentage = 0.2
for t in range(fine_tune_epochs):
    print(f"Fine-tuning Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, fine_tune_optimizer, scheduler, prune_percentage)
    test(test_dataloader, model, loss_fn)

Fine-tuning Epoch 1
-------------------------------
loss: 4.775121  [   64/60000]
loss: 3.519913  [ 6464/60000]
loss: 1.625332  [12864/60000]
loss: 1.334817  [19264/60000]
loss: 1.026526  [25664/60000]
loss: 1.057887  [32064/60000]
loss: 0.841550  [38464/60000]
loss: 0.919455  [44864/60000]
loss: 0.794893  [51264/60000]
loss: 0.597531  [57664/60000]
Test Error: 
 Accuracy: 75.5%, Avg loss: 0.722145 

Fine-tuning Epoch 2
-------------------------------
loss: 0.694820  [   64/60000]
loss: 0.730169  [ 6464/60000]
loss: 0.488343  [12864/60000]
loss: 0.702612  [19264/60000]
loss: 0.603308  [25664/60000]
loss: 0.671341  [32064/60000]
loss: 0.645819  [38464/60000]
loss: 0.805023  [44864/60000]
loss: 0.687872  [51264/60000]
loss: 0.580785  [57664/60000]
Test Error: 
 Accuracy: 80.6%, Avg loss: 0.547708 



In [17]:
# Main training loop
epochs = 5
start_time = time.time()
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer, scheduler, prune_percentage)
    test(test_dataloader, model, loss_fn)
end_time = time.time()
training_time = end_time - start_time
print('Finished Training')
print('Training time: ', training_time, 'seconds')
print("Done!")

Epoch 1
-------------------------------
loss: 0.398902  [   64/60000]
loss: 0.668929  [ 6464/60000]
loss: 0.423881  [12864/60000]
loss: 0.661852  [19264/60000]
loss: 0.472019  [25664/60000]
loss: 0.568383  [32064/60000]
loss: 0.585043  [38464/60000]
loss: 0.647204  [44864/60000]
loss: 0.678222  [51264/60000]
loss: 0.504215  [57664/60000]
Test Error: 
 Accuracy: 80.3%, Avg loss: 0.552233 

Epoch 2
-------------------------------
loss: 0.494767  [   64/60000]
loss: 0.733721  [ 6464/60000]
loss: 0.424727  [12864/60000]
loss: 0.667936  [19264/60000]
loss: 0.608046  [25664/60000]
loss: 0.536186  [32064/60000]
loss: 0.615958  [38464/60000]
loss: 0.785528  [44864/60000]
loss: 0.610756  [51264/60000]
loss: 0.537841  [57664/60000]
Test Error: 
 Accuracy: 80.2%, Avg loss: 0.556307 

Epoch 3
-------------------------------
loss: 0.619932  [   64/60000]
loss: 0.563121  [ 6464/60000]
loss: 0.392936  [12864/60000]
loss: 0.631039  [19264/60000]
loss: 0.574563  [25664/60000]
loss: 0.606908  [32064/600

In [18]:
model.to(device)
def calculate_metrics(model, dataloader):
    correct = 0
    total = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for data in dataloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            y_true.extend(labels.tolist())
            y_pred.extend(predicted.tolist())

    accuracy = correct / total
    precision, recall, _, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')

    return accuracy, precision, recall

In [19]:
# Calculate metrics
accuracy, precision, recall = calculate_metrics(model, test_dataloader)

print(f'Accuracy: {accuracy}')
print(f'Precision: {precision}')
print(f'Recall: {recall}')

Accuracy: 0.8022
Precision: 0.8114664738759864
Recall: 0.8022


In [20]:
def get_model_size(model):
    torch.save(model.state_dict(), "model.pth")
    size_MB = os.path.getsize("model.pth")/1e6
    return size_MB

def get_inference_time(model, input_shape=(1, 1, 28, 28), repeat=100):
    device = next(model.parameters()).device  # Get the device of the model
    model.eval()
    input_data = torch.randn(input_shape).to(device)  # Move input data to the same device as the model
    start_time = time.time()
    for _ in range(repeat):
        with torch.no_grad():
            _ = model(input_data)
    return (time.time() - start_time) / repeat

# Get model size
print(f'Model size: {get_model_size(model)} MB')

# Get inference time
print(f'Average inference time: {get_inference_time(model)} s')

Model size: 4.169343 MB
Average inference time: 0.003504776954650879 s


In [21]:
model = ADLNet()
model.load_state_dict(torch.load("model.pth"))

<All keys matched successfully>

In [22]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
x = x.unsqueeze(0)
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

Predicted: "Ankle boot", Actual: "Ankle boot"
