In [41]:
import numpy as np
import cv2 as cv
from sklearn.mixture import GaussianMixture as GMM
from collections import defaultdict
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.ops import roi_align
from pytorchcv.models.common import SEBlock
import random
import itertools

if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.set_per_process_memory_fraction(.9, device=device)
else:
    device = torch.device('cpu')         

print(device)

class CopySingleChannels(object):
    ''' Does nothing to the image if it has three channels. Creates 3 copies if it has one channel. Returns assertion error if neither.'''
    
    def __call__(self, image):
        assert (image.shape[0] == 1 or image.shape[0] == 3)

        if image.shape[0] == 1:       
            return image.repeat(3, 1, 1)
        else:
            return image

transform = transforms.Compose([
    torchvision.transforms.RandomAffine(degrees=15, translate= (.15, .15), scale = (0.85, 1.15)),
    transforms.Resize((224,224)), #Resize data to be 224x224.
    transforms.ToTensor(),
    transforms.Lambda(CopySingleChannels()) #transforms both RGB and grayscale to same tensor size by stacking copies of grayscales along first axis.
    
    ])


# Get the data.
full_dataset = torchvision.datasets.Caltech256(root='./data', download=True, transform=transform)
# Get the indices of the data for splitting it later.
full_indices = full_dataset.index



# Generate breaks to generate a list of lists containing indices for each category.
brks = [i for i in range(1,len(full_indices)) if full_indices[i] < full_indices[i-1]]
split_indices = [full_indices[x:y] for x,y in zip([0]+brks,brks+[None])]

# We do not want to include the 257th category, clutter, because they don't use it in the paper, so 
split_indices = split_indices[:256]

# Get 60 random indices from each category.
rand_indices = [sorted(random.sample(i, k = 60)) for i in split_indices]
# Concatenate the list of lists to a list.
train_indices = list(itertools.chain.from_iterable(rand_indices))

assert len(train_indices) == 256*60

# Get the indices not included in the train indices.
get_test_indices = [list(set(split_indices[i]) - set(rand_indices[i])) for i in range(len(split_indices))]
# Concatenate the list of lists to a list.
test_indices = list(itertools.chain.from_iterable(get_test_indices))


train_dataset = torch.utils.data.Subset(full_dataset, train_indices)

test_dataset = torch.utils.data.Subset(full_dataset, test_indices)


batch_sze = 8
batched_train_data = torch.utils.data.DataLoader(train_dataset, batch_size = batch_sze, shuffle = True)


#images, labels = next(iter(batched_train_data))
"""
img_t = np.transpose(images[0], (1, 2, 0))

plt.imshow(img_t)
plt.show()
"""

cuda:0
Files already downloaded and verified


'\nimg_t = np.transpose(images[0], (1, 2, 0))\n\nplt.imshow(img_t)\nplt.show()\n'

<h3> Intuitive Mechanism Behind SIFT Clustering.</h3>

The code below shows how we came to arrive at the function "get_SR_coords()" later defined. Note that although this code shows the regions on the original image, the regions are actually applied to the output of the backbone CNN, in this case Resnet50, through RoI_Align().

In [42]:
""" Note that you input the image batch output from Resnet in the actual implmentation of getting the SR's. This is just for conceptual purposes. """




"""
def get_SRs_and_other_info(image_batch):

    # Convert a single image to numpy array of expected shape for cv inputs
    im2 = image_batch[0].detach().numpy().transpose(1,2,0)
    # Rescale data for expected type 
    im2 = (im2*255).astype(np.uint8)

    
    sift = cv.SIFT_create()
    kp, desc = sift.detectAndCompute(im2,None)

    # Sort keypoints by their response in descending order
    keypoints_lst = sorted(kp, key=lambda x: x.response, reverse=True)

    # Apply threshold implicitly by resticting to top 50% keypoints.
    num_keypoints_to_keep = int(.5*len(keypoints_lst))
    keypoints = keypoints_lst[:num_keypoints_to_keep]
    print(num_keypoints_to_keep)
    
    # Option to include associated descriptors
    #filtered_descriptors = np.array([desc[i] for i in range(len(keypoints)) if i < num_keypoints_to_keep])

    # Convert Keypoints to np Array
    pts = cv.KeyPoint_convert(keypoints)


    # Let kappa = 8 as the paper suggests.
    kappa = 8

    # Gaussian Mixture Model
    gmm_pts = GMM(n_components= kappa, covariance_type= 'full', init_params= 'kmeans').fit(pts)
    gmm_labels = gmm_pts.predict(pts)


    # Sort pts into clusters
    clusters = defaultdict(list)
    for pt, label in zip(pts, gmm_labels):
        clusters[label].append(pt)

    # Compute min and max values for each cluster to get the bounding boxes
    Prim_SRs = {}
    for label, points in clusters.items():
        xpts, ypts = zip(*points)
        Prim_SRs[label] = {
            'min_x': min(xpts),
            'max_x': max(xpts),
            'min_y': min(ypts),
            'max_y': max(ypts), 
            'Primary_SR': image_batch[0,:,int(min(ypts)):int(max(ypts)),int(min(xpts)):int(max(xpts))]
    }


    # Note that we can get the primary SR's by "diagonal" elements    
    Secon_SRs = {}
    for i, itemi in Prim_SRs.items():
        for j, itemj in Prim_SRs.items():
            if (i <= j):
                Secon_SRs[(i,j)] = {
                    'Secondary_SR': image_batch[0,:,int(min(itemi['min_y'], itemj['min_y'])):int(max(itemi['max_y'], itemj['max_y'])),int(min(itemi['min_x'], itemj['min_x'])):int(max(itemi['max_x'], itemj['max_x']))]
                }
            


    return pts, clusters, Secon_SRs



pts, clusters, SR_fam = get_SRs_and_other_info(images)

plt.imshow(img_t)
plt.title(f"{full_dataset.categories[labels[0]]}")
xpts, ypts = zip(*pts)
#plt.scatter(xpts, ypts, c = gmm_labels)

# Get points within first cluster


SR = SR_fam[(0,1)]['Secondary_SR']


clust1 = clusters[0]
x1pts, y1pts = zip(*clust1)
plt.scatter(x1pts, y1pts, c = 'b')
clust2 = clusters[1]
x2pts, y2pts = zip(*clust2)
plt.scatter(x2pts, y2pts, c = 'r')


plt.show()

crop_t = np.transpose(SR, (1, 2, 0))
plt.imshow(crop_t)
plt.show()
"""

'\ndef get_SRs_and_other_info(image_batch):\n\n    # Convert a single image to numpy array of expected shape for cv inputs\n    im2 = image_batch[0].detach().numpy().transpose(1,2,0)\n    # Rescale data for expected type \n    im2 = (im2*255).astype(np.uint8)\n\n    \n    sift = cv.SIFT_create()\n    kp, desc = sift.detectAndCompute(im2,None)\n\n    # Sort keypoints by their response in descending order\n    keypoints_lst = sorted(kp, key=lambda x: x.response, reverse=True)\n\n    # Apply threshold implicitly by resticting to top 50% keypoints.\n    num_keypoints_to_keep = int(.5*len(keypoints_lst))\n    keypoints = keypoints_lst[:num_keypoints_to_keep]\n    print(num_keypoints_to_keep)\n    \n    # Option to include associated descriptors\n    #filtered_descriptors = np.array([desc[i] for i in range(len(keypoints)) if i < num_keypoints_to_keep])\n\n    # Convert Keypoints to np Array\n    pts = cv.KeyPoint_convert(keypoints)\n\n\n    # Let kappa = 8 as the paper suggests.\n    kappa =

<h3>Showing Some Backbone</h3>

We get our CNN backbone, Resnet50, weights from PyTorch, and thanks to their improved training algorithm, these weights are improved over the original Resnet50 weights, later likely allowing our model to get 98.3% accuracy, higher than the 96% mean in the paper. Note that a lot of the commented out sections were used for debugging and/or to be sure that each module matches the number of paramters from the original paper.

In [43]:
""" 
    PyTorch has the architecture for ResNet50 already available with the best pre-trained ImageNet weights. 
    We use Pyotrch's improved training weights instead of the original weights to better capture the features.
"""
class CustomResNet50(nn.Module):
    def __init__(self, device):
        super(CustomResNet50, self).__init__()
        resnet50 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT).to(device)
        
        # Extract the first 5 blocks (Note that these are the conv1, conv2_x,..., conv_5x layers).
        self.get_blocks = nn.Sequential(*list(resnet50.children())[:8])
        
    
    def forward(self, x):
        return self.get_blocks(x)

#res_map = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT).to(device)
#print(list(res_map.children())[7])


"""
res_feature_map = CustomResNet50(device).to(device)


# Outputs "torch.Size([Batch_sze, 2048, 7, 7])", so C = 256
resnet_output = res_feature_map(images.type(torch.FloatTensor).to(device))
print(resnet_output.size())

model_parameters = filter(lambda p: p.requires_grad, res_feature_map.parameters())
res_params = sum([np.prod(p.size()) for p in model_parameters])
print(f"{res_params:,}")
"""

'\nres_feature_map = CustomResNet50(device).to(device)\n\n\n# Outputs "torch.Size([Batch_sze, 2048, 7, 7])", so C = 256\nresnet_output = res_feature_map(images.type(torch.FloatTensor).to(device))\nprint(resnet_output.size())\n\nmodel_parameters = filter(lambda p: p.requires_grad, res_feature_map.parameters())\nres_params = sum([np.prod(p.size()) for p in model_parameters])\nprint(f"{res_params:,}")\n'

<h3> Getting the Semantic Region Coordinates</h3>

This uses a very similar mechanism to group clusters as demonstrated in the commented our cell. One modification is that we want to get the lower-left and upper-right coords of the bounding boxes to feed into RoI_Align later. Note that RoI Align uses bilinear interpolation, and this is the bilinear pooling of the features mentioned in the paper.

In [44]:
"""
We now get the coordinates of the SR's. This takes as input the image batch and outputs a list of tensors of different sizes.
"""

def get_SRs_coords(image_batch, loc_device, kappa=8):
    batch_SRs = []
    region_counts = []

    B, C, H, W = image_batch.size()

    for k in range(B):
        # Convert a single image to numpy array of expected shape for cv inputs
        copy_tens = image_batch[k].clone().cpu()
        im2 = copy_tens.detach().numpy().transpose(1, 2, 0)
        # Rescale data for expected type for CV 
        im2 = (im2 * 255).astype(np.uint8)

        sift = cv.SIFT_create()
        kp, desc = sift.detectAndCompute(im2, None)

        # Sort keypoints by their response in descending order
        keypoints_lst = sorted(set(kp), key=lambda x: x.response, reverse=True)

        Secon_SRs = []
        # Regardless of how many kp are detected, we want to include the whole image.
        Secon_SRs.append(torch.tensor([0, H - 1, 0, W - 1]))

        if len(keypoints_lst) > 1:
            num_keypoints_to_keep = max(int(.5 * len(keypoints_lst)), 1)
            if len(keypoints_lst) < 9:
                num_keypoints_to_keep = len(keypoints_lst)
                kappa = 2

            keypoints = keypoints_lst[:num_keypoints_to_keep]
            pts = cv.KeyPoint_convert(keypoints)

            unique_pts = np.unique(pts, axis=0)
            kappa_loc = min(kappa, len(unique_pts))

            if kappa_loc > 1:  # Only fit GMM if there are at least 2 unique points
                gmm_pts = GMM(n_components=kappa_loc, covariance_type='full', init_params='kmeans').fit(unique_pts)
                gmm_labels = gmm_pts.predict(pts)

                # Sort pts into clusters
                clusters = defaultdict(list)
                for pt, label in zip(pts, gmm_labels):
                    clusters[label].append(pt)

                # Compute min and max values for x and y coords of each cluster to get the bounding boxes
                Prim_SRs = {}
                for label, points in clusters.items():
                    xpts, ypts = zip(*points)
                    Prim_SRs[label] = {
                        'min_x': min(xpts),
                        'max_x': max(xpts),
                        'min_y': min(ypts),
                        'max_y': max(ypts),
                    }

                # Note that we can get the primary SR's via "diagonal" elements (when i = j )   
                for i, itemi in Prim_SRs.items():
                    for j, itemj in Prim_SRs.items():
                        if i <= j:
                            x1, x2 = min(itemi['min_y'], itemj['min_y']), max(itemi['max_y'], itemj['max_y'])
                            y1, y2 = min(itemi['min_x'], itemj['min_x']), max(itemi['max_x'], itemj['max_x'])

                            Secon_SRs.append(torch.tensor([x1, y1, x2, y2]))

        tens = torch.stack(Secon_SRs).to(loc_device)
        batch_SRs.append(tens)
        region_counts.append(tens.size(0))

    return batch_SRs, region_counts

# List of size B X different R+1 X 4
#SR_list, R_counts = get_SRs_coords(images, device, 5)


<h3> Creating the Necessary Modules </h3>

We implement AG-Net by building it piece-by-piece then assembling into one Torch module later for training. Note that if you use a different backbone CNN, you will possibly need to adjust the scale of the RoI and the number of input channels. As before, we leave some of the commented-out code to "show our work".

In [45]:

""" Given the feature maps, we now want to begin the intra-attention mechanism. """

class IntraSelfAttn(nn.Module):
    """ 
        Self Attention Layer adapted from SAGAN Implementation available at https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py
    """
    def __init__(self, in_channels = 2048, B = 8):
        super(IntraSelfAttn,self).__init__()
        self.chanel_in = in_channels
        
        self.query_conv = nn.Conv2d(in_channels = in_channels , out_channels = in_channels//8 , kernel_size= 1) # f(x)
        self.key_conv = nn.Conv2d(in_channels = in_channels , out_channels = in_channels//8 , kernel_size= 1) # g(x)
        self.value_conv = nn.Conv2d(in_channels = in_channels , out_channels = in_channels , kernel_size= 1) # h(x)
        self.delta = nn.Parameter(torch.zeros(1)) # Delta initialized to be 0

        self.softmax  = nn.Softmax(dim=-1) 
     def forward(self,x):
        """
            inputs :
                x : input feature maps (B x C x W x H)
            returns :
                out : self attention value + input feature (same size)
        """
        B,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(B,-1,width*height).permute(0,2,1) # B x C x W*H
        proj_key =  self.key_conv(x).view(B,-1,width*height) # B x C x W*H
        energy =  torch.bmm(proj_query,proj_key) 
        attention = self.softmax(energy) # B x W*H x W*H 
        proj_value = self.value_conv(x).view(B,-1,width*height) # B x C x W*H

        out = torch.bmm(proj_value,attention.permute(0,2,1) ).view(B,C,width,height)
        
        out = self.delta*out + x
        return out

#Intra = IntraSelfAttn().to(device)
#Intra_out = Intra(resnet_output)
#model_parameters = filter(lambda p: p.requires_grad, Intra.parameters())
#intra_params = sum([np.prod(p.size()) for p in model_parameters])
#print(f"Intra params: {intra_params:,}")
#print(Intra_out.size())

"""  
roi_align will take as input the Intra-self attention features and 
an input parameter boxes which takes in a list of box tensors [batch_idx, x1, y1, x2, y2]. 
obtained with get_SRs(). 
If a batch is passed, we must have that the first column of boxes contains the 
index of the corresponding element in the batch.


Returns a tensor of size sum(R_i + 1) x output_size x output_size, where R_i is the number of SR's detected in image i from a batch.
"""

# SR_list gets coordinates from 224x224 images, but we want to assign them to 56 x 56 since these are the Resnet50 output dimensions.
#scale = 7/224

#roi = roi_align(Intra_out, SR_list, output_size= 7, spatial_scale= scale).to(device)

#print(roi.size())



class SE_Residual(nn.Module):

    """ Takes in batch of bilinearly pooled features. Outputs tensor of same size. Rp1 = R + 1. """
    def __init__(self, channels = 2048, Rp1 = 37):
        super(SE_Residual, self).__init__()
        self.SE = nn.ModuleList([SEBlock(channels=channels, reduction=16) for i in range(Rp1)])

    def forward(self, feature_tens, region_counts):
        
        B = len(region_counts)
        
        
        feature_list = torch.split(feature_tens, region_counts, dim = 0)
        outputs = []
        for i in range(B):
            for j in range(region_counts[i]):
                x = self.SE[j](feature_list[i][j])
                resid = x + feature_list[i][j].unsqueeze(0)
                outputs.append(resid)
                

        return torch.concat(outputs, dim=0)

#channel_num = roi.size()[1]
#SE_res = SE_Residual().to(device)
#model_parameters = filter(lambda p: p.requires_grad, SE_res.parameters())
#se_params = sum([np.prod(p.size()) for p in model_parameters])
#print(f"SE res params: {se_params:,}")
#SE_out = SE_res(roi, R_counts)


class InterSelfAttn(nn.Module):
    def __init__(self, in_channels = 2048, B = 8):
        super(InterSelfAttn, self).__init__()
        self.in_channels = in_channels
        
        # Shared 1x1 convolutions for parameter efficiency
        self.W_u = torch.nn.ModuleList([nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)  for i in range(B)])
        self.W_u_prime = torch.nn.ModuleList([nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)  for i in range(B)])
        self.W_m = torch.nn.ModuleList([nn.Conv2d(in_channels // 8, 1, kernel_size=1)  for i in range(B)])
        self.W_alpha = torch.nn.ModuleList([nn.Conv2d(in_channels, 1, kernel_size=1)  for i in range(B)])
        
    def forward(self, feature_tens, region_counts):
        B = len(region_counts)
        r = feature_tens.size(-1)

        feature_list = torch.split(feature_tens, region_counts)
        
        
        outputs = []
        for i in range(B):
            R = region_counts[i]
            regions = feature_list[i]  # R x C x H x W
            
            
            # Compute u_{r,r'}
            u_r = self.W_u[i](regions).unsqueeze(0)  # 1 x R x C/8 x H x W
            u_r_prime = self.W_u_prime[i](regions).unsqueeze(1)  # R x 1 x C/8 x H x W
            u_r_r_prime = torch.tanh(u_r + u_r_prime)  # R x R x C/8 x H x W
            
            # Compute m_{r,r'}
            m_r_r_prime = self.W_m[i](u_r_r_prime.view(R * R, -1, r, r)).view(R, R, 1, r, r)  # R x R x 1 x H x W 
            m_r_r_prime = torch.sigmoid(m_r_r_prime)
            
            # Aggregate attentional features
            alpha_r = torch.sum(m_r_r_prime * regions.unsqueeze(1), dim=0)  # R x C x H x W
            
            # Compute weights w_r
            w_r = self.W_alpha[i](alpha_r).view(R, -1)  
            w_r = F.softmax(w_r, dim=0).view(R, 1, r, r)  # R x 1 x H x W
            
            # Combine regional features
            f_hat = torch.sum(alpha_r * w_r, dim=0)  # C x H x W
            outputs.append(f_hat)
        
        outputs = torch.stack(outputs, dim=0)  # B x C x H x W
        
        return outputs
#k = 5
#kappa_max = k*(k+1)//2 + 1

#print(kappa_max)
#inter = InterSelfAttn().to(device)


#model_parameters = filter(lambda p: p.requires_grad, inter.parameters())
#inter_params = sum([np.prod(p.size()) for p in model_parameters])
#print(f"Inter params: {inter_params:,}")
#inter_out = inter(SE_out, R_counts)
#print(inter_out.size())

In [46]:
class Classify(nn.Module):
    def __init__(self, C = 2048, num_classes = 256):
        super(Classify, self).__init__()
        self.softmax = nn.Softmax(dim = 1)
        # Layers for classification
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.gmp = nn.AdaptiveMaxPool2d((1, 1))
        
        self.Womega = nn.Linear(C, 1)
        self.bomega = nn.Parameter(torch.zeros(1))
        self.fc = nn.Linear(C, num_classes)

    def forward(self, x):
        """
            inputs:
                x: input batch of features (B x C x 7 x 7)
            returns:
                out: class probabilities (B x num_classes)
        """
        B, C, H, W = x.size()

        # Apply GAP and GMP
        f_gap = self.gap(x).view(B, C)
        f_gmp = self.gmp(x).view(B, C)

        # Calculate omega and combine features
        omega = self.softmax(self.Womega(f_gap) + self.bomega)
        F = omega * f_gmp + (1 - omega) * f_gap
        
        # Classification
        class_probs = self.fc(F)

        return class_probs
        
"""    
c = Classify().to(device)
c_out = c(inter_out)
#print(c_out.size())
model_parameters = filter(lambda p: p.requires_grad, c.parameters())
classify_params = sum([np.prod(p.size()) for p in model_parameters])
print(f"Classify params: {classify_params:,}")
"""

'    \nc = Classify().to(device)\nc_out = c(inter_out)\n#print(c_out.size())\nmodel_parameters = filter(lambda p: p.requires_grad, c.parameters())\nclassify_params = sum([np.prod(p.size()) for p in model_parameters])\nprint(f"Classify params: {classify_params:,}")\n'

In [47]:
"""
print(f"Resnet params: {res_params:,}")

total_params_baseline = classify_params + res_params 
# Should be around 23.62M, but our count is 24.03M 
print(f"Baseline params: {total_params_baseline:,}")

# Should be around 43.02M
total_params_no_attn = total_params_baseline + se_params  
print(f"-Attn params: {total_params_no_attn:,}")

# Should be around 54.79M
total_params = total_params_no_attn + inter_params+ intra_params
print(f"+Attn params: {total_params:,}")
"""

'\nprint(f"Resnet params: {res_params:,}")\n\ntotal_params_baseline = classify_params + res_params \n# Should be around 23.62M, but our count is 24.03M \nprint(f"Baseline params: {total_params_baseline:,}")\n\n# Should be around 43.02M\ntotal_params_no_attn = total_params_baseline + se_params  \nprint(f"-Attn params: {total_params_no_attn:,}")\n\n# Should be around 54.79M (not including last layers)\ntotal_params = total_params_no_attn + inter_params+ intra_params\nprint(f"+Attn params: {total_params:,}")\n'

<h3> Putting Everything Together </h3>

In [48]:
class AG_Net(nn.Module):
    def __init__(self, kappa, batch_size_loc, device_loc):
        super(AG_Net, self).__init__()
        self.dev = device_loc
        self.batch = batch_size_loc
        self.kap = kappa
        self.kappa_max = 1+kappa*(kappa+1)//2
        self.c_res = CustomResNet50(device).to(device)
        self.intra = IntraSelfAttn().to(device_loc)
        self.se_res = SE_Residual(channels= 2048, Rp1= self.kappa_max).to(device_loc)
        self.inter = InterSelfAttn(B= batch_size_loc).to(device_loc)
        self.classify = Classify().to(device_loc)
    
    def forward(self, x):
        res_feats = self.c_res(x)
        SR_list, region_counts = get_SRs_coords(x, self.dev, self.kap)
        x = self.intra(res_feats)
        x = roi_align(x, SR_list, output_size= 7, spatial_scale= 7/224).to(device)
        x = self.se_res(x, region_counts)
        x = self.inter(x, region_counts)
        x = self.classify(x)        
        return x

kappa_global = 8
net = AG_Net(kappa = kappa_global, batch_size_loc=batch_sze, device_loc= device).to(device)
#net_out = net(SR_tens)


#model_parameters = filter(lambda p: p.requires_grad, net.parameters())
#params = sum([np.prod(p.size()) for p in model_parameters])
#print(f"{params:,}")

#Should be about 56.79M
#print(f"{params:,}")

<h3> Training the Model (Possibly From Checkpoints)</h3>

In [49]:
import os


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-5, momentum=0.99)


def load_checkpoint(model, optimizer, train_indices, test_indices, device, filename='AG_net_weights2.pth'):
    # Note: Input model & optimizer should be pre-defined.  This routine only updates their states.
    start_epoch = 0
    loaded_flag = False
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename, map_location=device)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        train_indices = checkpoint['train_indices']
        test_indices = checkpoint['test_indices']
        loaded_flag = True
        print("=> loaded checkpoint '{}' (epoch {})"
                  .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, train_indices, test_indices, loaded_flag

# Loading the saved model
 
net, optimizer, start_epoch, train_indices, test_indices, loaded_flag = load_checkpoint(net, optimizer, train_indices, test_indices, device)
net = net.to(device)

if loaded_flag:
    # now individually transfer the optimizer parts...
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

    train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
    test_dataset = torch.utils.data.Subset(full_dataset, test_indices)
    batched_train_data = torch.utils.data.DataLoader(train_dataset, batch_size = batch_sze, shuffle = True)

net.train()
end_epoch = 50

for epoch in range(start_epoch, end_epoch):  # loop over the dataset multiple times
    running_loss = 0.0
    if (epoch > 23):
        optimizer = optim.SGD(net.parameters(), lr=1e-6, momentum=0.99)
    for i, batch in enumerate(batched_train_data):
        
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = batch
        
        inputs, labels = inputs.to(device), labels.type(torch.LongTensor).to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs).type(torch.FloatTensor).to(device)
        #print(outputs, labels)
        loss = criterion(outputs, labels)
        #loss.requires_grad = True
        loss.backward()
        optimizer.step()

        # print statistics
        #"""
        running_loss += loss.item()
        print_num = 1920
        if i % print_num == (print_num - 1):    # print every print_num batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / print_num:.3f}')
            running_loss = 0.0
        #"""
print('Finished Training')

=> no checkpoint found at 'AG_net_weights2.pth'
[1,  1920] loss: 0.496
[2,  1920] loss: 0.029
[3,  1920] loss: 0.017
[4,  1920] loss: 0.009
[5,  1920] loss: 0.009
[6,  1920] loss: 0.007
[7,  1920] loss: 0.003
[8,  1920] loss: 0.007
[9,  1920] loss: 0.003
[10,  1920] loss: 0.003
[11,  1920] loss: 0.003
[12,  1920] loss: 0.004
[13,  1920] loss: 0.002
[14,  1920] loss: 0.002
[15,  1920] loss: 0.002
[16,  1920] loss: 0.002
[17,  1920] loss: 0.002
[18,  1920] loss: 0.001
[19,  1920] loss: 0.001
[20,  1920] loss: 0.004
[21,  1920] loss: 0.001
[22,  1920] loss: 0.001
[23,  1920] loss: 0.002
[24,  1920] loss: 0.001
[25,  1920] loss: 0.001
[26,  1920] loss: 0.001
[27,  1920] loss: 0.000
[28,  1920] loss: 0.001
[29,  1920] loss: 0.001
[30,  1920] loss: 0.001
[31,  1920] loss: 0.001
[32,  1920] loss: 0.001
[33,  1920] loss: 0.002
[34,  1920] loss: 0.000
[35,  1920] loss: 0.001
[36,  1920] loss: 0.001
[37,  1920] loss: 0.001
[38,  1920] loss: 0.001
[39,  1920] loss: 0.001
[40,  1920] loss: 0.001
[

<h3> Saving Training Progress </h3>

In [50]:
# Get the current working directory
notebook_dir = os.getcwd()

# Define the filename for saving the model
filename = 'AG_net_weights2.pth'

# Construct the full path to save the model
save_path = os.path.join(notebook_dir, filename)

state = {'epoch': epoch + 1, 'state_dict': net.state_dict(),
             'optimizer': optimizer.state_dict(), 'train_indices': train_indices, 'test_indices': test_indices}
torch.save(state, filename)

<h3> Testing the Model </h3>

In [51]:
# Define a DataLoader for the test dataset
batch_size_test = 8
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False)

# Set the model to evaluation mode
net.eval()

correct = 0
total = 0

# Disable gradient computation during evaluation
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        outputs = net(inputs)
        
        blah, predicted = torch.max(outputs, 1)
        
        total += labels.size(0)
        
        correct += (predicted == labels).sum().item()

# Calculate accuracy
accuracy = 100 * correct / total
print(f'Accuracy on the test dataset: {accuracy:.2f}%')

Accuracy on the test dataset: 98.34%
