<a href="https://colab.research.google.com/github/STKalinowski/CurriculumMiniExperiment/blob/main/CurriculumMiniExperiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.15.5-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m28.4 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.31-py3-none-any.whl (184 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.27.1-py2.py3-none-any.whl (211 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.7/211.7 kB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pathtools (from wandb)
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.2-cp310-cp310-manylinux_2_5_x86_64.manyli

In [2]:
# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import wandb

In [3]:
wandb.login()
if torch.cuda.is_available():
  device='cuda'
else:
  device = 'cpu'

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [4]:
# Load CIFAR10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

# Basic cnn model for classification
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 102866194.14it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
def train(net, inputloader, criterion, optimizer, epochs=5, aug=None):
  wandb.init(project='cifar-10-curriculum')
  trainLoader = inputloader
  for epoch in range(epochs):
      running_loss = 0.0
      if aug != None and epoch > epochs/2:
        trainLoader = aug
      for i, data in enumerate(trainloader, 0):
          inputs, labels = data
          inputs = inputs.to(device)
          labels = labels.to(device)

          optimizer.zero_grad()
          outputs = net(inputs)
          loss = criterion(outputs, labels)
          loss.backward()
          optimizer.step()

          running_loss += loss.item()
            # Log the loss to wandb
          if i % 500 == 499:
              wandb.log({"loss": running_loss / 500})
              print(f"Loss: {running_loss/500}")
              running_loss = 0.0
  print('Finished Training')

def test(net, testloader):
  correct = 0
  total = 0
  with torch.no_grad():
      for data in testloader:
          images, labels = data
          images = images.to(device)
          labels = labels.to(device)
          outputs = net(images)
          _, predicted = torch.max(outputs.data, 1)
          total += labels.size(0)
          correct += (predicted == labels).sum().item()

  print('Accuracy of the network on the 10000 test images: %d %%' % (
      100 * correct / total))

  # Log the accuracy to wandb
  wandb.log({"accuracy": 100 * correct / total})

In [7]:
# Normal Training
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

net = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

print("Training with Normal Data")
train(net, trainloader, criterion, optimizer)
test(net, testloader)

# Transformed Training
transformed_trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
    download=True, transform=transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))

trainloaderTransform = torch.utils.data.DataLoader(transformed_trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

net = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

print("Training with Transformed Data")
train(net, trainloaderTransform, criterion, optimizer)
test(net, testloader)

# Curriculum Learning: Start training on normal examples then on transformed examples
# Curriculum Learning: Start training on normal examples then on transformed examples
normalLoader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
augmentedLoader = torch.utils.data.DataLoader(transformed_trainset, batch_size=4, shuffle=True, num_workers=2)

net = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

print("Training with Curriculum Learning: Normal Data")
train(net, trainloader, criterion, optimizer, epochs=3)
train(net, trainloaderTransform, criterion, optimizer, epochs=2)
test(net,testloader)
# Save the model in wandb after training
torch.save(net.state_dict(), 'model.pth')
wandb.save('model.pth')

Training with Normal Data


VBox(children=(Label(value='0.240 MB of 0.240 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
loss,█▇▆▅▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁

0,1
accuracy,55.53
loss,1.34182


Loss: 2.297978175163269
Loss: 2.2255441834926604
Loss: 2.0862493991851805
Loss: 1.9798955466747283
Loss: 1.9228882415294648
Loss: 1.8458090317249298
Loss: 1.7790764880180359
Loss: 1.766755762219429
Loss: 1.6958447407484054
Loss: 1.7041249678134918
Loss: 1.6774685114622117
Loss: 1.6313792017698288
Loss: 1.6214277245998383
Loss: 1.6080850701332092
Loss: 1.5686259409189225
Loss: 1.5471848529577255
Loss: 1.5465326497554779
Loss: 1.5439244900345803
Loss: 1.4967546699643135
Loss: 1.5221444606781005
Loss: 1.4884879332780838
Loss: 1.4826671985983848
Loss: 1.479960719048977
Loss: 1.438953110039234
Loss: 1.4256020321846008
Loss: 1.4092968809604645
Loss: 1.3693743135929108
Loss: 1.4022363979220391
Loss: 1.4267043988704682
Loss: 1.4050057340264321
Loss: 1.3603072600364685
Loss: 1.3503976846635342
Loss: 1.3689780922532082
Loss: 1.3836105006635189
Loss: 1.3623139261901378
Loss: 1.353773936122656
Loss: 1.3506781966090202
Loss: 1.315790104418993
Loss: 1.3277869950085879
Loss: 1.3323861611485481
Loss: 

0,1
accuracy,▁
loss,█▆▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁

0,1
accuracy,60.43
loss,1.01064


Loss: 2.3024748015403746
Loss: 2.298315933227539
Loss: 2.272396462440491
Loss: 2.1058586480617523
Loss: 1.997493007659912
Loss: 1.8945587503910064
Loss: 1.8201909369230271
Loss: 1.773915493130684
Loss: 1.7089863454103469
Loss: 1.7235103439092636
Loss: 1.6530662940740586
Loss: 1.6406683470010757
Loss: 1.574110768198967
Loss: 1.5913887355327607
Loss: 1.592277640938759
Loss: 1.5618173005580902
Loss: 1.534687442779541
Loss: 1.4937778004109858
Loss: 1.4418753499388695
Loss: 1.4621395069360732
Loss: 1.4511635722517968
Loss: 1.4575303582549095
Loss: 1.4573650490045547
Loss: 1.439979199588299
Loss: 1.4073495568633079
Loss: 1.37481294798851
Loss: 1.3488082140088082
Loss: 1.3808560036420823
Loss: 1.3419456604123114
Loss: 1.3481253707259893
Loss: 1.3027082492709159
Loss: 1.356162904381752
Loss: 1.308189390540123
Loss: 1.3240607059597969
Loss: 1.3053435969650746
Loss: 1.294950517207384
Loss: 1.3151400989890099
Loss: 1.2759876271784305
Loss: 1.295755820453167
Loss: 1.292619519084692
Loss: 1.2550458

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
loss,█▇▆▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂▂▁▁▁▁▁▂▁▁

0,1
accuracy,61.0
loss,1.00575


Loss: 2.3031026372909547
Loss: 2.286165090560913
Loss: 2.1935161893367767
Loss: 2.0415911247730256
Loss: 1.9717539274692535
Loss: 1.8874446133375167
Loss: 1.8500157750844954
Loss: 1.8151082367897033
Loss: 1.7307893394231797
Loss: 1.7040085604190827
Loss: 1.6609907054901123
Loss: 1.6178744409680366
Loss: 1.5923722318410873
Loss: 1.6022363498210908
Loss: 1.5735687564611436
Loss: 1.5700633655786513
Loss: 1.567354952275753
Loss: 1.5240212404727935
Loss: 1.5494854490756989
Loss: 1.5378765276074409
Loss: 1.4550044565796851
Loss: 1.4558188856840133
Loss: 1.484426328957081
Loss: 1.4613259699940682
Loss: 1.419522746503353
Loss: 1.4194828647971154
Loss: 1.410789497792721
Loss: 1.4252662521004678
Loss: 1.408457118988037
Loss: 1.3910288268923758
Loss: 1.3843841537833215
Loss: 1.394004211127758
Loss: 1.3532126961946487
Loss: 1.3691512407660484
Loss: 1.3311749114096165
Loss: 1.3914020463526249
Loss: 1.3427557989805936
Loss: 1.3521537560522556
Loss: 1.3552615171074867
Loss: 1.2756456194221975
Loss: 1

0,1
loss,██▆▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,1.19954


Loss: 1.0869918142855168
Loss: 1.110479244440794
Loss: 1.1625303364247084
Loss: 1.1173639398962258
Loss: 1.1216248940080404
Loss: 1.1455185522735118
Loss: 1.1347975483983754
Loss: 1.1382238976210355
Loss: 1.127514985471964
Loss: 1.1708383712917567
Loss: 1.1317302220612764
Loss: 1.1335107963830233
Loss: 1.1367434515058994
Loss: 1.1332891656011344
Loss: 1.1476827265471221
Loss: 1.1403328889012336
Loss: 1.1476793823093177
Loss: 1.1460343796014785
Loss: 1.1133862878382206
Loss: 1.1103349249958991
Loss: 1.1480095807760955
Loss: 1.1181915133595466
Loss: 1.11602081874758
Loss: 1.1323489962220192
Loss: 1.0757169922105967
Loss: 1.027612505711615
Loss: 1.0735466218218208
Loss: 1.0723423093110322
Loss: 1.0831872099488973
Loss: 1.014936152525246
Loss: 1.0352993437796831
Loss: 1.0517954688668252
Loss: 1.0862175545543433
Loss: 1.052161943331361
Loss: 1.0827106609791517
Loss: 1.0391405949816108
Loss: 1.0665453854873777
Loss: 1.0783128101713955
Loss: 1.110246451575309
Loss: 1.0282973987851292
Loss: 1.

['/content/wandb/run-20230707_093303-o4zvkqbc/files/model.pth']