In [1]:
import json

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.transforms import v2

from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [2]:
import masks
from exp_base import get_accuracy, train, plot_probs, init_image

In [4]:
output = {
    "num_classes": [],
    "acc": [],
    "probs": [],
    "epochs": []
}

In [None]:
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

NUM_CLASSES = range(2, 65, 4)
BATCH_SIZE = 64  
IMAGE_SIZE = 64
EPOCHS = 5000
MASK = masks.MaskLine
# CRITERION = FocalLoss(gamma=1) 
CRITERION = nn.CrossEntropyLoss()

# writer = SummaryWriter(f"runs/experiment 1", comment=f"{BATCH_SIZE=}\n{IMAGE_SIZE=}\n{EPOCHS=}\n{MASK.__name__=}\n{CRITERION.__class__.__name__=}")
print(f"Running experiment 1")

for num_classes in NUM_CLASSES:
    univ_image = init_image(IMAGE_SIZE)
    
    print(f"---------- num_classes={num_classes} ----------")
    
    univ_image, epochs = train(
        model=model,
        univ_image=univ_image,
        epochs=EPOCHS,
        batch_size=num_classes,
        num_classes=num_classes,
        criterion=CRITERION,
        mask=MASK,
        writer=None
    )
    
    acc, probs = get_accuracy(model, univ_image, num_classes, MASK)
    
    output["num_classes"].append(num_classes)
    output["acc"].append(acc /  num_classes)
    output["probs"].append(probs)
    output["epochs"].append(epochs)

with open("result.json", "w") as f:
    json.dump(output, f)

Running experiment 1
---------- num_classes=2 ----------
EPOCH = 0/5000 | loss = 28.90737533569336 | acc=0.0 | lr=0.1 | time=0:0
CLOSING EPOCH = 72/5000 | loss = 0.823917806148529 | acc=1.0 | lr=0.1 | time=0:0
---------- num_classes=6 ----------
EPOCH = 0/5000 | loss = 31.735252380371094 | acc=0.0 | lr=0.1 | time=0:0
CLOSING EPOCH = 169/5000 | loss = 0.28455066680908203 | acc=1.0 | lr=0.1 | time=0:0
---------- num_classes=10 ----------
EPOCH = 0/5000 | loss = 33.864959716796875 | acc=0.0 | lr=0.1 | time=0:0
CLOSING EPOCH = 482/5000 | loss = 0.30100518465042114 | acc=1.0 | lr=0.1 | time=0:0
---------- num_classes=14 ----------
EPOCH = 0/5000 | loss = 27.021169662475586 | acc=0.0 | lr=0.1 | time=0:0
EPOCH = 500/5000 | loss = 0.4026038348674774 | acc=0.9285714285714286 | lr=0.1 | time=1:20
CLOSING EPOCH = 516/5000 | loss = 0.2962760031223297 | acc=1.0 | lr=0.1 | time=1:20
---------- num_classes=18 ----------
EPOCH = 0/5000 | loss = 32.07823181152344 | acc=0.0 | lr=0.1 | time=0:0
CLOSING E

In [7]:
print(output)

{'num_classes': [2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62], 'acc': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'probs': [[array(0.07589086, dtype=float32), array(0.9964557, dtype=float32)], [array(0.08404803, dtype=float32), array(0.9538117, dtype=float32), array(0.9398153, dtype=float32), array(0.9516147, dtype=float32), array(0.9814642, dtype=float32), array(0.9769743, dtype=float32)], [array(0.16387106, dtype=float32), array(0.94675624, dtype=float32), array(0.67618275, dtype=float32), array(0.5646349, dtype=float32), array(0.9842951, dtype=float32), array(0.6997943, dtype=float32), array(0.9909429, dtype=float32), array(0.08116138, dtype=float32), array(0.09859002, dtype=float32), array(0.6809992, dtype=float32)], [array(0.72104967, dtype=float32), array(0.877647, dtype=float32), array(0.9380749, dtype=float32), array(0.927086, dtype=float32), array(0.96320266, dtype=float32), array(0.9859242, dtype=float32), array(0.9900430

In [8]:
output["probs"]

[[array(0.07589086, dtype=float32), array(0.9964557, dtype=float32)],
 [array(0.08404803, dtype=float32),
  array(0.9538117, dtype=float32),
  array(0.9398153, dtype=float32),
  array(0.9516147, dtype=float32),
  array(0.9814642, dtype=float32),
  array(0.9769743, dtype=float32)],
 [array(0.16387106, dtype=float32),
  array(0.94675624, dtype=float32),
  array(0.67618275, dtype=float32),
  array(0.5646349, dtype=float32),
  array(0.9842951, dtype=float32),
  array(0.6997943, dtype=float32),
  array(0.9909429, dtype=float32),
  array(0.08116138, dtype=float32),
  array(0.09859002, dtype=float32),
  array(0.6809992, dtype=float32)],
 [array(0.72104967, dtype=float32),
  array(0.877647, dtype=float32),
  array(0.9380749, dtype=float32),
  array(0.927086, dtype=float32),
  array(0.96320266, dtype=float32),
  array(0.9859242, dtype=float32),
  array(0.99004304, dtype=float32),
  array(0.10618024, dtype=float32),
  array(0.06846671, dtype=float32),
  array(0.70305616, dtype=float32),
  array(

In [9]:
ans = []
for prob_array in output["probs"]:
    cur_ans = [float(x) for x in prob_array]
    ans.append(cur_ans)

In [10]:
ans

[[0.07589086145162582, 0.996455729007721],
 [0.08404802531003952,
  0.9538117051124573,
  0.9398152828216553,
  0.9516146779060364,
  0.9814642071723938,
  0.9769743084907532],
 [0.16387106478214264,
  0.9467562437057495,
  0.676182746887207,
  0.5646349191665649,
  0.984295129776001,
  0.6997942924499512,
  0.9909428954124451,
  0.08116137981414795,
  0.09859001636505127,
  0.680999219417572],
 [0.7210496664047241,
  0.8776469826698303,
  0.9380748867988586,
  0.9270859956741333,
  0.9632026553153992,
  0.9859241843223572,
  0.990043044090271,
  0.10618024319410324,
  0.06846670806407928,
  0.7030561566352844,
  0.8688160181045532,
  0.7111480236053467,
  0.7220498919487,
  0.8505417108535767],
 [0.05357353761792183,
  0.9913641214370728,
  0.8781933188438416,
  0.9803565144538879,
  0.9952436089515686,
  0.9033644795417786,
  0.9851232171058655,
  0.8652047514915466,
  0.6971597075462341,
  0.895591139793396,
  0.9548839926719666,
  0.9138321280479431,
  0.9789782166481018,
  0.91969

In [11]:
output["probs"] = ans

In [12]:
with open("./result.json", "w") as f:
    json.dump(output, f)