In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# Human-readable names for the 100 dataset labels. The list is indexed by the
# integer label (train() looks names up as classes[label]), so the order of
# the words below is significant and must not be re-sorted.
classes = """
apple aquarium_fish baby bear beaver bed bee beetle bicycle bottle bowl boy bridge bus butterfly
camel can castle caterpillar cattle chair chimpanzee clock cloud cockroach couch crab crocodile cup
dinosaur dolphin elephant flatfish forest fox girl hamster house kangaroo computer_keyboard
lamp lawn_mower leopard lion lizard lobster man maple_tree motorcycle mountain mouse mushroom
oak_tree orange orchid otter palm_tree pear pickup_truck pine_tree plain plate poppy porcupine possum
rabbit raccoon ray road rocket rose
sea seal shark shrew skunk skyscraper snail snake spider squirrel streetcar sunflower sweet_pepper
table tank telephone television tiger tractor train trout tulip turtle
wardrobe whale willow_tree wolf woman worm
""".split()

In [None]:

class DWN(nn.Module):
    """Small convolutional classifier producing 100 class scores per image.

    Every convolution preserves the spatial size (3x3 with padding 1, 1x1,
    5x5 with padding 2), and the final 1-channel map is flattened into the
    ``32*32`` features expected by ``fc1``; inputs must therefore have shape
    ``(batch, 3, 32, 32)``.

    NOTE(review): ``bn`` is applied after both ``conv1`` and ``conv2``, and
    ``bn2n`` after both ``conv2n`` and ``conv3``, so each pair of layers shares
    one set of affine parameters and running statistics. Separate BatchNorm
    layers would be the conventional choice, but adding them would change the
    state_dict keys and break existing checkpoints, so the sharing is kept
    here — confirm whether it is intentional.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 3x3 convolutions at 96 channels.
        self.conv1 = nn.Conv2d(3, 96, 3, 1, 1)
        self.bn = nn.BatchNorm2d(96)
        self.conv2 = nn.Conv2d(96, 96, 3, 1, 1)

        # 1x1 bottleneck from 96 down to 32 channels.
        self.conv2n = nn.Conv2d(96, 32, 1)
        self.bn2n = nn.BatchNorm2d(32)

        # 5x5 convolution; padding=2 keeps the spatial size unchanged.
        self.conv3 = nn.Conv2d(32, 32, 5, stride=1, padding=2)

        # 1x1 convolution collapsing the 32 channels into a single map.
        self.conv4 = nn.Conv2d(32, 1, 1, 1, 0)
        self.bn4 = nn.BatchNorm2d(1)

        self.activation = nn.ReLU()
        # Classifier head over the flattened 32x32 single-channel map.
        self.fc1 = nn.Linear(32*32, 100)

    def forward(self, x):
        """Return raw (unnormalised) class logits of shape ``(batch, 100)``.

        Args:
            x: image batch of shape ``(batch, 3, 32, 32)``.
        """
        x = self.activation(self.bn(self.conv1(x)))
        x = self.activation(self.bn(self.conv2(x)))

        x = self.activation(self.bn2n(self.conv2n(x)))
        x = self.activation(self.bn2n(self.conv3(x)))

        x = self.activation(self.bn4(self.conv4(x)))

        x = torch.flatten(x, 1)
        # No activation on the output: nn.CrossEntropyLoss expects raw logits,
        # and a ReLU here would clamp every negative logit to 0 and stop
        # gradient from flowing through them.
        return self.fc1(x)



In [None]:
# Use the first CUDA GPU when PyTorch reports one as available; otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('Define device:', device)

# Checkpoint file that train() reads weights from when asked to resume.
PATH = './best_model.pth'

def _predict(net, loader):
    """Return ``(labels, predictions)`` as two lists of ints for every sample in ``loader``.

    The network is switched to eval mode for the pass so that BatchNorm uses
    its running statistics and does not update them from evaluation data; the
    previous train/eval mode is restored afterwards, even on error.
    """
    labels_seen, predictions_seen = [], []
    was_training = net.training
    net.eval()
    try:
        # No gradients are needed for inference.
        with torch.no_grad():
            for images, labels in loader:
                outputs = net(images.to(device))
                # The class with the highest score is the prediction.
                _, predicted = torch.max(outputs, 1)
                labels_seen.extend(labels.tolist())
                predictions_seen.extend(predicted.tolist())
    finally:
        net.train(was_training)
    return labels_seen, predictions_seen


def train(ep=5, lr=0.001, ml=True):
    """Train a ``DWN`` on CIFAR-100 and print overall and per-class test accuracy.

    Every 100 mini-batches the running training loss is printed, the whole
    test set is evaluated, and the weights are written to ``PATH`` whenever
    the number of correct test predictions exceeds the best seen so far in
    this call. After the last epoch the per-class test accuracy of the final
    weights is printed.

    Args:
        ep: number of passes over the training set.
        lr: learning rate for the SGD optimizer.
        ml: if truthy, initialise the network from the checkpoint at ``PATH``
            before training; the file must exist in that case.
    """
    net = DWN().to(device)
    if ml:
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
        net.load_state_dict(torch.load(PATH, map_location=device))
    net.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=0.0001, momentum=0.8)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    batch_size = 64

    trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=0)

    testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=0)

    # Summarise the model: parameter shapes, total parameter count, layer list.
    for name, parameters in net.named_parameters():
        print(name, ':', parameters.shape)

    print(sum(p.numel() for p in net.parameters()))
    print(net)
    print('=======================================================================================================')

    best_correct = 0
    for epoch in range(ep):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader, 0):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}', end=' | ')
                running_loss = 0.0

                test_labels, test_predictions = _predict(net, testloader)
                total = len(test_labels)
                correct = sum(lab == pred for lab, pred in zip(test_labels, test_predictions))

                print(f'Accuracy test: {100 * correct // total} %')

                # Keep the weights that scored best on the test set so far.
                if correct > best_correct:
                    best_correct = correct
                    torch.save(net.state_dict(), PATH)

    # prepare to count predictions for each class
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    # collect the correct predictions for each class
    test_labels, test_predictions = _predict(net, testloader)
    for label, prediction in zip(test_labels, test_predictions):
        classname = classes[label]
        if label == prediction:
            correct_pred[classname] += 1
        total_pred[classname] += 1

    # print accuracy for each class
    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

    print('Finished Training')


Define device: cuda:0


In [None]:
# Train for 5 epochs at learning rate 0.01 without loading an existing checkpoint.
train(ep=5, lr=0.01, ml=False)

Files already downloaded and verified
Files already downloaded and verified
conv1.weight : torch.Size([96, 3, 3, 3])
conv1.bias : torch.Size([96])
bn.weight : torch.Size([96])
bn.bias : torch.Size([96])
conv2.weight : torch.Size([96, 96, 3, 3])
conv2.bias : torch.Size([96])
conv2n.weight : torch.Size([32, 96, 1, 1])
conv2n.bias : torch.Size([32])
bn2n.weight : torch.Size([32])
bn2n.bias : torch.Size([32])
conv3.weight : torch.Size([32, 32, 5, 5])
conv3.bias : torch.Size([32])
conv4.weight : torch.Size([1, 32, 1, 1])
conv4.bias : torch.Size([1])
bn4.weight : torch.Size([1])
bn4.bias : torch.Size([1])
fc1.weight : torch.Size([100, 1024])
fc1.bias : torch.Size([100])
217255
DWN(
  (conv1): Conv2d(3, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2n): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))
  (bn2n): Ba