In [1]:
import json

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.transforms import v2

from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [2]:
import masks
from exp_base import get_accuracy, train, plot_probs, init_image

In [3]:
output = {
    "num_classes": [],
    "acc": [],
    "probs": [],
    "epochs": []
}

In [4]:
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

NUM_CLASSES = range(2, 65)
BATCH_SIZE = 64  
IMAGE_SIZE = 64
EPOCHS = 5000
MASK = masks.MaskImageAll
# CRITERION = FocalLoss(gamma=1) 
CRITERION = nn.CrossEntropyLoss()

# writer = SummaryWriter(f"runs/experiment 1", comment=f"{BATCH_SIZE=}\n{IMAGE_SIZE=}\n{EPOCHS=}\n{MASK.__name__=}\n{CRITERION.__class__.__name__=}")
print(f"Running experiment 1")

for num_classes in NUM_CLASSES:
    univ_image = init_image(IMAGE_SIZE)
    
    print(f"---------- num_classes={num_classes} ----------")
    
    univ_image, epochs = train(
        model=model,
        univ_image=univ_image,
        epochs=EPOCHS,
        batch_size=num_classes,
        num_classes=num_classes,
        criterion=CRITERION,
        mask=MASK,
        writer=None
    )
    
    acc, probs = get_accuracy(model, univ_image, num_classes, MASK)
    
    output["num_classes"].append(num_classes)
    output["acc"].append(acc /  num_classes)
    output["probs"].append(probs)
    output["epochs"].append(epochs)

with open("result.json", "w") as f:
    json.dump(output, f)

Running experiment 1
---------- num_classes=2 ----------
EPOCH = 0/5000 | loss = 19.175540924072266 | acc=0.0 | lr=0.1 | time=0:0
EPOCH = 500/5000 | loss = 2.857185125350952 | acc=0.5 | lr=0.1 | time=0:18
EPOCH = 1000/5000 | loss = 2.7755274772644043 | acc=0.5 | lr=0.1 | time=0:18
EPOCH = 1500/5000 | loss = 2.675570011138916 | acc=0.5 | lr=0.1 | time=0:18
EPOCH = 2000/5000 | loss = 2.616549253463745 | acc=0.5 | lr=0.1 | time=0:17
EPOCH = 2500/5000 | loss = 2.0676348209381104 | acc=0.5 | lr=0.1 | time=0:17
CLOSING EPOCH = 2505/5000 | loss = 1.3691562414169312 | acc=1.0 | lr=0.1 | time=0:17
---------- num_classes=3 ----------
EPOCH = 0/5000 | loss = 11.01796817779541 | acc=0.0 | lr=0.1 | time=0:0
CLOSING EPOCH = 73/5000 | loss = 1.1550992727279663 | acc=1.0 | lr=0.1 | time=0:0
---------- num_classes=4 ----------
EPOCH = 0/5000 | loss = 10.397782325744629 | acc=0.0 | lr=0.1 | time=0:0
CLOSING EPOCH = 76/5000 | loss = 0.7052490711212158 | acc=1.0 | lr=0.1 | time=0:0
---------- num_classes=

TypeError: Object of type ndarray is not JSON serializable

In [5]:
print(output)

{'num_classes': [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64], 'acc': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9629629629629629, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9761904761904762, 1.0, 1.0, 1.0, 0.9347826086956522, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9807692307692307, 0.9433962264150944, 0.9074074074074074, 0.8727272727272727, 1.0, 0.9298245614035088, 1.0, 0.9152542372881356, 1.0, 0.8852459016393442, 0.9193548387096774, 0.9206349206349206, 0.9375], 'probs': [[array(0.07316099, dtype=float32), array(0.9987956, dtype=float32)], [array(0.07864078, dtype=float32), array(0.8153608, dtype=float32), array(0.6689917, dtype=float32)], [array(0.13852656, dtype=float32), array(0.901497

In [9]:
output["probs"]

[[array(0.07316099, dtype=float32), array(0.9987956, dtype=float32)],
 [array(0.07864078, dtype=float32),
  array(0.8153608, dtype=float32),
  array(0.6689917, dtype=float32)],
 [array(0.13852656, dtype=float32),
  array(0.901497, dtype=float32),
  array(0.80306834, dtype=float32),
  array(0.82563114, dtype=float32)],
 [array(0.06026242, dtype=float32),
  array(0.92808753, dtype=float32),
  array(0.64179057, dtype=float32),
  array(0.611583, dtype=float32),
  array(0.23703682, dtype=float32)],
 [array(0.04337703, dtype=float32),
  array(0.9812339, dtype=float32),
  array(0.59612966, dtype=float32),
  array(0.85854316, dtype=float32),
  array(0.9128962, dtype=float32),
  array(0.86954856, dtype=float32)],
 [array(0.07645092, dtype=float32),
  array(0.96087843, dtype=float32),
  array(0.9571215, dtype=float32),
  array(0.91717476, dtype=float32),
  array(0.9540518, dtype=float32),
  array(0.9341437, dtype=float32),
  array(0.96823746, dtype=float32)],
 [array(0.06638242, dtype=float32),


In [15]:
ans = []
for prob_array in output["probs"]:
    cur_ans = [float(x) for x in prob_array]
    ans.append(cur_ans)

In [16]:
ans

[[0.07316099107265472, 0.9987956285476685],
 [0.07864078134298325, 0.8153607845306396, 0.6689916849136353],
 [0.1385265588760376,
  0.9014970064163208,
  0.8030683398246765,
  0.8256311416625977],
 [0.060262419283390045,
  0.9280875325202942,
  0.6417905688285828,
  0.6115829944610596,
  0.2370368242263794],
 [0.04337703064084053,
  0.9812338948249817,
  0.5961296558380127,
  0.8585431575775146,
  0.9128962159156799,
  0.8695485591888428],
 [0.07645092159509659,
  0.9608784317970276,
  0.9571214914321899,
  0.917174756526947,
  0.9540517926216125,
  0.9341437220573425,
  0.9682374596595764],
 [0.06638241559267044,
  0.4993000328540802,
  0.846018373966217,
  0.535443127155304,
  0.9643420577049255,
  0.4212600886821747,
  0.943128228187561,
  0.9366093873977661],
 [0.2714327573776245,
  0.5555797219276428,
  0.5217588543891907,
  0.6000574231147766,
  0.9660618901252747,
  0.6780185103416443,
  0.9615605473518372,
  0.10607416182756424,
  0.04645666107535362],
 [0.09072411060333252,
  

In [17]:
output["probs"] = ans

In [19]:
with open("./result.json", "w") as f:
    json.dump(output, f)