In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os

In [15]:
class expert(nn.Module):
    """LeNet-style CNN classifier producing 10 output logits.

    The layer sizes (``fc1`` expects ``16*5*5`` flattened features) imply
    3-channel 32x32 inputs: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5.

    On construction the parameters are either restored from ``path`` (when
    that file exists) or, when it does not, the freshly initialised
    parameters are saved to ``path``. Every instance built with the same
    ``path`` therefore starts from identical weights.

    Args:
        path: file used to store / restore the initial parameters.
    """

    def __init__(self, path="initial_weights.pth"):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        if os.path.exists(path):
            self.load_weights(path)
        else:
            self._log_initial_weights(path)

    def _log_initial_weights(self, path="initial_weights.pth"):
        """Save a detached copy of every named parameter to ``path``.

        The file holds a dict mapping parameter name -> tensor. It is the
        format that ``load_weights`` reads back.
        """
        initial_weights = {name: param.clone().detach() for name, param in self.named_parameters()}
        torch.save(initial_weights, path)
        print(f"Initial weights saved to '{path}'")

    def load_weights(self, path):
        """Copy parameters previously saved by ``_log_initial_weights`` from ``path``.

        Raises:
            KeyError: if the file has no entry for one of this model's parameters.
        """
        # map_location="cpu" makes a file written from any device loadable;
        # copy_ then moves each tensor onto the parameter's own device.
        weights = torch.load(path, map_location="cpu")
        # In-place writes to leaf parameters that require grad must bypass autograd.
        with torch.no_grad():
            for name, param in self.named_parameters():
                param.copy_(weights[name])

    def forward(self, x):
        """Return class logits of shape (batch, 10) for an image batch ``x``."""
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

In [16]:
# Configuration: batch size, RNG seed, epochs and compute device.
BATCH_SIZE = 128
SEED = 0
N_epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so that the model's random init (used when no
# initial-weights file exists yet) and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

# Images are only converted to tensors; no normalization is applied.
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

network = expert().to(device)
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(network.parameters(), lr=0.001)

# Summed training loss per epoch, appended by the training loop.
loss_list = []

Files already downloaded and verified
Files already downloaded and verified
Initial weights saved to 'initial_weights.pth'


In [17]:
# Train for N_epochs; loss_list records the summed (not averaged) batch loss per epoch.
network.train()
for epoch in range(N_epochs):
    combined_loss = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        # CrossEntropyLoss accepts integer class indices directly; this gives the
        # same loss as a one-hot probability target without building one.
        labels = labels.to(device)
        pred = network(inputs)
        loss = criterion(pred, labels)
        combined_loss = combined_loss + loss.item()
        optim.zero_grad()
        loss.backward()
        optim.step()
    loss_list.append(combined_loss)
    print("epoch:", epoch, "loss:", combined_loss)

print("Training complete")
path = "./single_expert_model"
# NOTE(review): this pickles the whole nn.Module, tying the file to this class
# definition. torch.save(network.state_dict(), path) is the more portable
# format, but switching would change what downstream loaders expect.
torch.save(network, path)

epoch: 0 loss: 732.3535149097443
epoch: 1 loss: 609.5219674110413
epoch: 2 loss: 566.2838814258575
epoch: 3 loss: 535.9174718856812
epoch: 4 loss: 510.3687844276428
epoch: 5 loss: 492.48207956552505
epoch: 6 loss: 474.77973771095276
epoch: 7 loss: 457.45296412706375
epoch: 8 loss: 444.54776549339294
epoch: 9 loss: 433.0480182170868
epoch: 10 loss: 419.88516598939896
epoch: 11 loss: 412.8006114959717
epoch: 12 loss: 400.78377175331116
epoch: 13 loss: 392.8532781600952
epoch: 14 loss: 384.0487329363823
epoch: 15 loss: 376.6059029698372
epoch: 16 loss: 369.9014803171158
epoch: 17 loss: 362.9875689148903
epoch: 18 loss: 354.13904958963394
epoch: 19 loss: 348.69776314496994
Training complete


In [18]:
def check(pred, labels, corr):
    """Add the number of correct top-1 predictions in a batch to a running count.

    Args:
        pred: score tensor of shape (batch, num_classes); argmax over dim 1
            is taken as the predicted class.
        labels: ground-truth class indices, one per row of ``pred``
            (tensor or sequence of ints).
        corr: running count of correct predictions so far.

    Returns:
        ``corr`` plus the number of rows whose predicted class equals the label.
    """
    predicted = torch.argmax(pred, dim=1)
    labels = torch.as_tensor(labels, device=predicted.device)
    # One vectorized comparison and a single .item() sync, instead of a Python
    # loop that synchronises with the device once per element.
    return corr + int((predicted == labels).sum().item())

In [19]:
# Testing: top-1 accuracy of the trained network on the test split.
network.eval()  # explicit inference mode (no dropout/batch-norm here, but keeps intent clear)
corr = 0
# Inference only: skip building the autograd graph for every test batch.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        pred = network(inputs)
        corr = check(pred, labels, corr)
print("accuracy:", corr/len(test_dataset))

accuracy: 0.6151
