In [93]:
!pip install torch torchvision wandb -qqq

In [2]:
!git clone https://github.com/jgamper/intrinsic-dimensionality.git -q

In [109]:
import wandb
!wandb login
# wandb.init(project='id_test_toy', entity='nayannkote')

[34m[1mwandb[0m: Currently logged in as: [33mnayannkote[0m (use `wandb login --relogin` to force relogin)


In [112]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import importlib
fastfood = importlib.import_module("intrinsic-dimensionality.intrinsic.fastfood")

In [113]:
class FCNet(nn.Module):
    def __init__(self):
        super(FCNet, self).__init__()
        # 784–200–200–10 from uber-research paper
        self.fc1 = nn.Linear(784,200)
        self.fc2 = nn.Linear(200,200)
        self.fc3 = nn.Linear(200,10)

    def forward(self,x):
        x = nn.Flatten(start_dim=1)(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        output = F.log_softmax(x,dim=1)

        return output

In [114]:
def train(
    model,
    device,
    train_loader,
    optimizer,
    epoch,
    log_interval=100
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx%log_interval == 0:
            print(f"Epoch:{epoch+1}\t\t Batch Number:{batch_idx}\t\t Loss:{loss.item()}")
            wandb.log({"batch loss":loss.item()})

def test(
    model,
    device,
    test_loader,
    test_batch_size,
    ID,
    log_interval=100
):
    model.eval()
    test_loss = 0
    num_correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss+=F.nll_loss(output,target)
            pred = output.argmax(dim=1,keepdim=True)
            num_correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= (len(test_loader)*test_batch_size) # Dividing by num_batchs*batch_size
    acc = num_correct*100/(len(test_loader)*test_batch_size)
    wandb.log({"test_loss":test_loss, "test_accuracy":acc}) 

    print(f"\nTest Loss: {test_loss}\t Accuarcy:{acc}%\n")

In [115]:
def main(
    ID,
    num_epochs,
    train_batch_size,
    test_batch_size,
    output_dir,
    learning_rate=1.0,
    gamma=0.1,
    use_cuda=True,
    seed=69,
    log_interval=100,
):
    
    # Use GPU or not
    use_cuda = torch.cuda.is_available() and use_cuda
    device = torch.device("cuda" if use_cuda else "cpu")

    # Setting seed
    torch.manual_seed(seed)

    # Setting train and test kwargs
    train_kwargs = {"batch_size":train_batch_size}
    test_kwargs = {"batch_size":test_batch_size}

    if use_cuda:
        cuda_kwargs={
            "num_workers":1, 
            "pin_memory":True, # Why pinning is faster, ref : https://forums.developer.nvidia.com/t/why-using-pinned-memory-is-faster/1948
            "shuffle":True
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    # Defining image transform, ref : https://stackoverflow.com/questions/63746182/correct-way-of-normalizing-and-scaling-the-mnist-dataset
    transform = transforms.Compose(
        [
         transforms.ToTensor(), # Converts input image to 3 dim tensor and values b/w [0,1]
         transforms.Normalize((0.1307,), (0.3081, )) # Mean 0.1307 and StdDev 0.3081
        ]
    )

    # Getting train and test datasets
    train_data = datasets.MNIST("/content/", train=True, download=True, transform=transform)
    test_data = datasets.MNIST("/content/", train=False, download=True, transform=transform)

    # Dataloaders
    train_loader = torch.utils.data.DataLoader(train_data,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_data,**test_kwargs)

    # Model, Optimizer and Learning Rate Scheduler
    model = FCNet().to(device)
    if ID:
        model = fastfood.FastfoodWrap(model, intrinsic_dimension=ID, device=device)
    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    # Training model
    print(f"Training model with ID: {ID}\n")
    for epoch in range(num_epochs):
        train(model,device,train_loader,optimizer,epoch,log_interval)
        test(model,device,test_loader,test_batch_size,ID)
        lr_scheduler.step()

    if output_dir!=None:
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': lr_scheduler.state_dict(),
            'epoch': epoch+1
        }, os.path.join(output_dir,f"mnist_fc_ID{ID}.pt"))

In [118]:
for ID in range(0,1001,100):
    # # Custom x axis
    # wandb.define_metric("ID")
    # # Values which are plotted with custom x axis
    # wandb.define_metric("test*", step_metric="ID")
    wandb.init(reinit=True, project='id_test_toy', entity='nayannkote', name=f"ID_{ID}", resume=None)
    main(ID=ID,num_epochs=1,train_batch_size=128,test_batch_size=128,output_dir="/content/")

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: [32m[41mERROR[0m Control-C detected -- Run data was not synced


Training model with ID: 0

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:0.18629644811153412
Epoch:1		 Batch Number:200		 Loss:0.1204913854598999
Epoch:1		 Batch Number:300		 Loss:0.20646442472934723
Epoch:1		 Batch Number:400		 Loss:0.18631988763809204

Test Loss: 0.0015600861515849829	 Accuarcy:92.72151898734177%



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch loss,0.18632
_runtime,15.0
_timestamp,1630052079.0
_step,5.0
test_loss,0.00156
test_accuracy,92.72152


0,1
batch loss,█▁▁▁▁
_runtime,▁▃▄▅▆█
_timestamp,▁▃▄▅▆█
_step,▁▂▄▅▇█
test_loss,▁
test_accuracy,▁


Training model with ID: 100

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:2.236151933670044
Epoch:1		 Batch Number:200		 Loss:2.097813367843628
Epoch:1		 Batch Number:300		 Loss:1.8429356813430786
Epoch:1		 Batch Number:400		 Loss:1.7076603174209595

Test Loss: 0.01227465458214283	 Accuarcy:49.61431962025316%



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch loss,1.70766
_runtime,59.0
_timestamp,1630052141.0
_step,5.0
test_loss,0.01227
test_accuracy,49.61432


0,1
batch loss,█▇▆▃▁
_runtime,▁▂▄▅▇█
_timestamp,▁▂▄▅▇█
_step,▁▂▄▅▇█
test_loss,▁
test_accuracy,▁


Training model with ID: 200

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:2.138045072555542
Epoch:1		 Batch Number:200		 Loss:1.7940938472747803
Epoch:1		 Batch Number:300		 Loss:1.4335778951644897
Epoch:1		 Batch Number:400		 Loss:1.336256504058838

Test Loss: 0.009130015037953854	 Accuarcy:62.78678797468354%



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch loss,1.33626
_runtime,58.0
_timestamp,1630052202.0
_step,5.0
test_loss,0.00913
test_accuracy,62.78679


0,1
batch loss,█▇▄▂▁
_runtime,▁▂▄▅▆█
_timestamp,▁▂▄▅▆█
_step,▁▂▄▅▇█
test_loss,▁
test_accuracy,▁


Training model with ID: 300

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:1.9123228788375854
Epoch:1		 Batch Number:200		 Loss:1.1883898973464966
Epoch:1		 Batch Number:300		 Loss:0.933940052986145
Epoch:1		 Batch Number:400		 Loss:0.9550138115882874

Test Loss: 0.006449082400649786	 Accuarcy:72.83425632911393%



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch loss,0.95501
_runtime,58.0
_timestamp,1630052264.0
_step,5.0
test_loss,0.00645
test_accuracy,72.83426


0,1
batch loss,█▆▂▁▁
_runtime,▁▂▄▅▇█
_timestamp,▁▂▄▅▇█
_step,▁▂▄▅▇█
test_loss,▁
test_accuracy,▁


Training model with ID: 400

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:1.7507984638214111
Epoch:1		 Batch Number:200		 Loss:0.8930684328079224
Epoch:1		 Batch Number:300		 Loss:0.8180958032608032
Epoch:1		 Batch Number:400		 Loss:0.7962849736213684

Test Loss: 0.00542991328984499	 Accuarcy:76.99762658227849%



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch loss,0.79628
_runtime,59.0
_timestamp,1630052326.0
_step,5.0
test_loss,0.00543
test_accuracy,76.99763


0,1
batch loss,█▅▁▁▁
_runtime,▁▂▄▅▇█
_timestamp,▁▂▄▅▇█
_step,▁▂▄▅▇█
test_loss,▁
test_accuracy,▁


Training model with ID: 500

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:1.6568373441696167
Epoch:1		 Batch Number:200		 Loss:0.7206133008003235
Epoch:1		 Batch Number:300		 Loss:0.6454060077667236
Epoch:1		 Batch Number:400		 Loss:0.7094405293464661

Test Loss: 0.004516534972935915	 Accuarcy:80.50830696202532%



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch loss,0.70944
_runtime,58.0
_timestamp,1630052388.0
_step,5.0
test_loss,0.00452
test_accuracy,80.50831


0,1
batch loss,█▅▁▁▁
_runtime,▁▂▄▅▆█
_timestamp,▁▂▄▅▆█
_step,▁▂▄▅▇█
test_loss,▁
test_accuracy,▁


Training model with ID: 600

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:1.3523584604263306
Epoch:1		 Batch Number:200		 Loss:0.6697553396224976
Epoch:1		 Batch Number:300		 Loss:0.5585227608680725
Epoch:1		 Batch Number:400		 Loss:0.6850382685661316

Test Loss: 0.00452557485550642	 Accuarcy:80.9434335443038%



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch loss,0.68504
_runtime,58.0
_timestamp,1630052450.0
_step,5.0
test_loss,0.00453
test_accuracy,80.94343


0,1
batch loss,█▄▁▁▂
_runtime,▁▂▄▅▆█
_timestamp,▁▂▄▅▆█
_step,▁▂▄▅▇█
test_loss,▁
test_accuracy,▁


Training model with ID: 700

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:1.3173428773880005
Epoch:1		 Batch Number:200		 Loss:0.5984770059585571
Epoch:1		 Batch Number:300		 Loss:0.549479067325592
Epoch:1		 Batch Number:400		 Loss:0.6874551773071289

Test Loss: 0.0039199963212013245	 Accuarcy:83.14873417721519%



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch loss,0.68746
_runtime,58.0
_timestamp,1630052511.0
_step,5.0
test_loss,0.00392
test_accuracy,83.14873


0,1
batch loss,█▄▁▁▂
_runtime,▁▂▄▅▇█
_timestamp,▁▂▄▅▇█
_step,▁▂▄▅▇█
test_loss,▁
test_accuracy,▁


Training model with ID: 800

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:1.1542048454284668
Epoch:1		 Batch Number:200		 Loss:0.49162694811820984
Epoch:1		 Batch Number:300		 Loss:0.5034583210945129
Epoch:1		 Batch Number:400		 Loss:0.5398048162460327

Test Loss: 0.003861360251903534	 Accuarcy:83.34651898734177%



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch loss,0.5398
_runtime,58.0
_timestamp,1630052573.0
_step,5.0
test_loss,0.00386
test_accuracy,83.34652


0,1
batch loss,█▄▁▁▁
_runtime,▁▂▄▅▆█
_timestamp,▁▂▄▅▆█
_step,▁▂▄▅▇█
test_loss,▁
test_accuracy,▁


Training model with ID: 900

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:1.0084731578826904
Epoch:1		 Batch Number:200		 Loss:0.5628271102905273
Epoch:1		 Batch Number:300		 Loss:0.6070367693901062
Epoch:1		 Batch Number:400		 Loss:0.5674901008605957

Test Loss: 0.003637476358562708	 Accuarcy:84.59256329113924%



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch loss,0.56749
_runtime,58.0
_timestamp,1630052634.0
_step,5.0
test_loss,0.00364
test_accuracy,84.59256


0,1
batch loss,█▃▁▁▁
_runtime,▁▂▄▅▆█
_timestamp,▁▂▄▅▆█
_step,▁▂▄▅▇█
test_loss,▁
test_accuracy,▁


Training model with ID: 1000

Epoch:1		 Batch Number:0		 Loss:2.2965505123138428
Epoch:1		 Batch Number:100		 Loss:0.9048476219177246
Epoch:1		 Batch Number:200		 Loss:0.4023589789867401
Epoch:1		 Batch Number:300		 Loss:0.48732519149780273
Epoch:1		 Batch Number:400		 Loss:0.5104105472564697

Test Loss: 0.003414246952161193	 Accuarcy:85.5320411392405%



In [119]:
baseline_model = FCNet()
trainable_params = 0
layers = 0
for name, param in baseline_model.named_parameters():
    layers+=1
    # print(name, param.size())
    if param.requires_grad : 
        trainable_params+=np.prod(param.size())

print(f"Trainable params : {trainable_params} and layers : {layers}")

Trainable params : 199210 and layers : 6
