<a href="https://colab.research.google.com/github/BaBa0525/machine-learning/blob/main/HW5/hw5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import package

In [None]:
import csv
import cv2
import numpy as np
import random
import os

from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

## Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/ML-HW5
!ls

/content/drive/MyDrive/Colab Notebooks/ML-HW5
captcha-hacker.zip  hw5.ipynb  input  submission.csv


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


## Check dataset

In [None]:
# import os

# for dirname, _, filenames in os.walk('./input'):
#     for filename in filenames[:3]:
#         print(os.path.join(dirname, filename))
#     if len(filenames) > 3:
#         print("...")


In [None]:
TRAIN_PATH = "./input/train"
TEST_PATH = "./input/test"

## Dataset Class

In [None]:
class Task1Dataset(Dataset):
  def __init__(self, data, root, return_filename=False):
    self.data = [sample for sample in data if sample[0].startswith("task1")]
    self.return_filename = return_filename
    self.root = root

  def __getitem__(self, index):
    filename, label = self.data[index]
    img = cv2.imread(f"{self.root}/{filename}")
    img = cv2.resize(img, (32, 32))
    img = np.mean(img, axis=2)
    if self.return_filename:
      return torch.FloatTensor((img - 128) / 128), filename

    return torch.FloatTensor((img - 128) / 128), int(label)

  def __len__(self):
    return len(self.data)
    

## Model Class

In [None]:
class Model(nn.Module):
  def __init__(self):
    super().__init__()
    self.layers = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5),
        nn.LeakyReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(in_channels=3, out_channels=9, kernel_size=5),
        nn.LeakyReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(9 * 5 * 5, 100),
        nn.LeakyReLU(),
        nn.Linear(100, 50),
        nn.LeakyReLU(),
        nn.Linear(50, 10)
    )

  def forward(self, x):
    # print(f"{x.shape=}")
    # b, h, w = x.shape
    # x = x.view(b, h*w)
    x = x[:, np.newaxis, :]
    print(f"{x.shape=}")
    return self.layers(x)

## Process training data

In [None]:
train_data = []
val_data = []

with open(f"{TRAIN_PATH}/annotations.csv", newline="") as csvfile:
  for row in csv.reader(csvfile, delimiter=","):
    if random.random() < 0.7:
      train_data.append(row)
    else:
      val_data.append(row)

train_ds = Task1Dataset(train_data, TRAIN_PATH)
train_dl = DataLoader(train_ds, batch_size=500, drop_last=True, shuffle=True)

val_ds = Task1Dataset(val_data, TRAIN_PATH)
val_dl = DataLoader(val_ds, batch_size=500, drop_last=True, shuffle=True)

## Training

In [None]:
model = Model().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in tqdm(range(100), leave=True):
  print(f"Epoch: [{epoch}]")

  model.train()
  for image, label in train_dl:
    print("---")
    image = image.to(device)
    label = label.to(device)

    pred = model(image)
    loss = loss_fn(pred, label)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  sample_count = 0
  correct_count = 0

  model.eval()
  for image, label in val_dl:
    image = image.to(device)
    label = label.to(device)

    pred = model(image)
    loss = loss_fn(pred, label)

    pred = torch.argmax(pred, dim=1)

    sample_count += len(image)
    correct_count += (label == pred).sum()

  print(f"accuracy(validation): {correct_count / sample_count}")



  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: [0]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


  1%|          | 1/100 [05:54<9:44:31, 354.26s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.11800000816583633
Epoch: [1]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


  2%|▏         | 2/100 [07:18<5:18:51, 195.22s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.11000000685453415
Epoch: [2]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


  3%|▎         | 3/100 [07:44<3:11:15, 118.30s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.11400000751018524
Epoch: [3]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


  4%|▍         | 4/100 [07:54<2:00:45, 75.47s/it] 

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.11800000816583633
Epoch: [4]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


  5%|▌         | 5/100 [07:59<1:19:06, 49.96s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.11600000411272049
Epoch: [5]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


  6%|▌         | 6/100 [08:02<53:28, 34.13s/it]  

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.11600000411272049
Epoch: [6]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


  7%|▋         | 7/100 [08:06<37:13, 24.02s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.11400000751018524
Epoch: [7]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


  8%|▊         | 8/100 [08:08<26:25, 17.23s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.12400000542402267
Epoch: [8]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


  9%|▉         | 9/100 [08:11<19:14, 12.68s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.12600000202655792
Epoch: [9]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 10%|█         | 10/100 [08:14<14:23,  9.60s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.14800000190734863
Epoch: [10]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 11%|█         | 11/100 [08:16<11:09,  7.53s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.17000000178813934
Epoch: [11]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 12%|█▏        | 12/100 [08:19<08:54,  6.07s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.17400000989437103
Epoch: [12]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 13%|█▎        | 13/100 [08:22<07:19,  5.06s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.15200001001358032
Epoch: [13]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 14%|█▍        | 14/100 [08:25<06:15,  4.37s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.164000004529953
Epoch: [14]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 15%|█▌        | 15/100 [08:28<05:31,  3.89s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.1600000113248825
Epoch: [15]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 16%|█▌        | 16/100 [08:30<04:58,  3.55s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.18400001525878906
Epoch: [16]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 17%|█▋        | 17/100 [08:33<04:34,  3.31s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.18800000846385956
Epoch: [17]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 18%|█▊        | 18/100 [08:36<04:17,  3.14s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.1940000057220459
Epoch: [18]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 19%|█▉        | 19/100 [08:39<04:05,  3.03s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.25
Epoch: [19]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 20%|██        | 20/100 [08:41<03:56,  2.96s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.24000000953674316
Epoch: [20]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 21%|██        | 21/100 [08:44<03:48,  2.89s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.23000000417232513
Epoch: [21]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 22%|██▏       | 22/100 [08:47<03:42,  2.85s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.2540000081062317
Epoch: [22]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 23%|██▎       | 23/100 [08:50<03:37,  2.82s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.27000001072883606
Epoch: [23]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 24%|██▍       | 24/100 [08:52<03:32,  2.80s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.29600000381469727
Epoch: [24]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 25%|██▌       | 25/100 [08:55<03:29,  2.80s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.31200000643730164
Epoch: [25]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 26%|██▌       | 26/100 [08:58<03:25,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.3240000009536743
Epoch: [26]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 27%|██▋       | 27/100 [09:01<03:21,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.32200002670288086
Epoch: [27]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 28%|██▊       | 28/100 [09:03<03:18,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.33800002932548523
Epoch: [28]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 29%|██▉       | 29/100 [09:06<03:17,  2.78s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.3660000264644623
Epoch: [29]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 30%|███       | 30/100 [09:09<03:12,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.4020000100135803
Epoch: [30]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 31%|███       | 31/100 [09:12<03:09,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.4140000343322754
Epoch: [31]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 32%|███▏      | 32/100 [09:14<03:07,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.39000001549720764
Epoch: [32]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 33%|███▎      | 33/100 [09:17<03:05,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.4300000071525574
Epoch: [33]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 34%|███▍      | 34/100 [09:20<03:02,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.47200003266334534
Epoch: [34]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 35%|███▌      | 35/100 [09:23<03:00,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.46400001645088196
Epoch: [35]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 36%|███▌      | 36/100 [09:25<02:57,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.49400001764297485
Epoch: [36]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 37%|███▋      | 37/100 [09:28<02:54,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.5200000405311584
Epoch: [37]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 38%|███▊      | 38/100 [09:31<02:51,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.5360000133514404
Epoch: [38]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 39%|███▉      | 39/100 [09:34<02:48,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.5540000200271606
Epoch: [39]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 40%|████      | 40/100 [09:36<02:45,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.5420000553131104
Epoch: [40]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 41%|████      | 41/100 [09:39<02:43,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.5520000457763672
Epoch: [41]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 42%|████▏     | 42/100 [09:42<02:39,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.550000011920929
Epoch: [42]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 43%|████▎     | 43/100 [09:45<02:37,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.5600000023841858
Epoch: [43]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 44%|████▍     | 44/100 [09:47<02:34,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.5879999995231628
Epoch: [44]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 45%|████▌     | 45/100 [09:50<02:32,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.5960000157356262
Epoch: [45]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 46%|████▌     | 46/100 [09:53<02:32,  2.83s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.5940000414848328
Epoch: [46]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 47%|████▋     | 47/100 [09:56<02:28,  2.81s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6100000143051147
Epoch: [47]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 48%|████▊     | 48/100 [09:59<02:27,  2.83s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6320000290870667
Epoch: [48]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 49%|████▉     | 49/100 [10:02<02:23,  2.82s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6200000047683716
Epoch: [49]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 50%|█████     | 50/100 [10:05<02:21,  2.83s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6360000371932983
Epoch: [50]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 51%|█████     | 51/100 [10:07<02:18,  2.83s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6600000262260437
Epoch: [51]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 52%|█████▏    | 52/100 [10:10<02:15,  2.83s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6620000600814819
Epoch: [52]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 53%|█████▎    | 53/100 [10:13<02:11,  2.81s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6540000438690186
Epoch: [53]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 54%|█████▍    | 54/100 [10:16<02:08,  2.79s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6820000410079956
Epoch: [54]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 55%|█████▌    | 55/100 [10:18<02:05,  2.78s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6820000410079956
Epoch: [55]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 56%|█████▌    | 56/100 [10:21<02:02,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6520000100135803
Epoch: [56]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 57%|█████▋    | 57/100 [10:24<01:58,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.6740000247955322
Epoch: [57]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 58%|█████▊    | 58/100 [10:27<01:56,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7020000219345093
Epoch: [58]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 59%|█████▉    | 59/100 [10:30<01:53,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7120000123977661
Epoch: [59]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 60%|██████    | 60/100 [10:32<01:50,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7120000123977661
Epoch: [60]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 61%|██████    | 61/100 [10:35<01:47,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.734000027179718
Epoch: [61]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 62%|██████▏   | 62/100 [10:38<01:44,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7240000367164612
Epoch: [62]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 63%|██████▎   | 63/100 [10:41<01:42,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7240000367164612
Epoch: [63]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 64%|██████▍   | 64/100 [10:43<01:38,  2.74s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.734000027179718
Epoch: [64]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 65%|██████▌   | 65/100 [10:46<01:36,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7420000433921814
Epoch: [65]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 66%|██████▌   | 66/100 [10:49<01:34,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7460000514984131
Epoch: [66]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 67%|██████▋   | 67/100 [10:52<01:31,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7460000514984131
Epoch: [67]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 68%|██████▊   | 68/100 [10:54<01:28,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7440000176429749
Epoch: [68]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 69%|██████▉   | 69/100 [10:57<01:25,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7660000324249268
Epoch: [69]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 70%|███████   | 70/100 [11:00<01:22,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7420000433921814
Epoch: [70]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 71%|███████   | 71/100 [11:03<01:19,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7640000581741333
Epoch: [71]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 72%|███████▏  | 72/100 [11:05<01:17,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7900000214576721
Epoch: [72]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 73%|███████▎  | 73/100 [11:08<01:14,  2.74s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7640000581741333
Epoch: [73]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 74%|███████▍  | 74/100 [11:11<01:11,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7720000147819519
Epoch: [74]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 75%|███████▌  | 75/100 [11:14<01:08,  2.73s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7780000567436218
Epoch: [75]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 76%|███████▌  | 76/100 [11:16<01:05,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7900000214576721
Epoch: [76]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 77%|███████▋  | 77/100 [11:19<01:03,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7720000147819519
Epoch: [77]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 78%|███████▊  | 78/100 [11:22<01:00,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7780000567436218
Epoch: [78]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 79%|███████▉  | 79/100 [11:25<00:57,  2.74s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7980000376701355
Epoch: [79]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 80%|████████  | 80/100 [11:27<00:55,  2.75s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8020000457763672
Epoch: [80]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 81%|████████  | 81/100 [11:30<00:52,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8040000200271606
Epoch: [81]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 82%|████████▏ | 82/100 [11:33<00:49,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8100000619888306
Epoch: [82]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 83%|████████▎ | 83/100 [11:36<00:46,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.796000063419342
Epoch: [83]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 84%|████████▍ | 84/100 [11:38<00:44,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7940000295639038
Epoch: [84]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 85%|████████▌ | 85/100 [11:41<00:41,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.800000011920929
Epoch: [85]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 86%|████████▌ | 86/100 [11:44<00:38,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8040000200271606
Epoch: [86]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 87%|████████▋ | 87/100 [11:47<00:36,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.7900000214576721
Epoch: [87]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 88%|████████▊ | 88/100 [11:49<00:33,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8080000281333923
Epoch: [88]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 89%|████████▉ | 89/100 [11:52<00:30,  2.79s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8300000429153442
Epoch: [89]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 90%|█████████ | 90/100 [11:55<00:27,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8160000443458557
Epoch: [90]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 91%|█████████ | 91/100 [11:58<00:24,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8080000281333923
Epoch: [91]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 92%|█████████▏| 92/100 [12:01<00:22,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.812000036239624
Epoch: [92]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 93%|█████████▎| 93/100 [12:04<00:19,  2.85s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8380000591278076
Epoch: [93]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 94%|█████████▍| 94/100 [12:06<00:16,  2.82s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8160000443458557
Epoch: [94]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 95%|█████████▌| 95/100 [12:09<00:14,  2.80s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8140000104904175
Epoch: [95]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 96%|█████████▌| 96/100 [12:12<00:11,  2.79s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8400000333786011
Epoch: [96]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 97%|█████████▋| 97/100 [12:15<00:08,  2.79s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8400000333786011
Epoch: [97]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 98%|█████████▊| 98/100 [12:17<00:05,  2.77s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8220000267028809
Epoch: [98]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


 99%|█████████▉| 99/100 [12:20<00:02,  2.76s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8360000252723694
Epoch: [99]
---
x.shape=torch.Size([500, 1, 32, 32])
---
x.shape=torch.Size([500, 1, 32, 32])


100%|██████████| 100/100 [12:23<00:00,  7.43s/it]

x.shape=torch.Size([500, 1, 32, 32])
accuracy(validation): 0.8460000157356262





In [None]:
# torch.save(model.state_dict(), "task1_model.pt")
print(model.state_dict())


OrderedDict([('layers.0.weight', tensor([[[[-0.0899, -0.0271, -0.0276,  0.0023, -0.0878],
          [-0.0096, -0.0608,  0.1565, -0.1445, -0.0741],
          [-0.0163, -0.0309,  0.0682,  0.0362, -0.0115],
          [-0.0362,  0.0962,  0.2872,  0.1128, -0.1070],
          [ 0.2070,  0.3048,  0.2565,  0.2978,  0.2694]]],


        [[[ 0.0765, -0.0575, -0.1304, -0.0434, -0.1573],
          [-0.0727, -0.0610,  0.0031, -0.2441, -0.2059],
          [ 0.0438, -0.0996,  0.2365, -0.0499, -0.3081],
          [ 0.1063,  0.0428,  0.0502,  0.2179,  0.0092],
          [-0.1050,  0.3127,  0.3362,  0.0311, -0.2202]]],


        [[[ 0.0103, -0.1601, -0.0670,  0.0253,  0.3175],
          [ 0.0317, -0.2246, -0.0945, -0.1124,  0.1035],
          [ 0.2479, -0.0798, -0.2734,  0.1377,  0.1286],
          [ 0.0254, -0.1637, -0.1532,  0.0660, -0.0410],
          [-0.0656,  0.1607, -0.1212, -0.2007,  0.0443]]]], device='cuda:0')), ('layers.0.bias', tensor([-0.0487,  0.0451,  0.1996], device='cuda:0')), ('layers.

## Output prediction

In [None]:
test_data = []

with open(f"{TEST_PATH}/../sample_submission.csv", newline="") as csvfile:
  for row in csv.reader(csvfile, delimiter=","):
    test_data.append(row)

test_ds = Task1Dataset(test_data, TEST_PATH, return_filename=True)
test_dl = DataLoader(test_ds, batch_size=500, num_workers=2, drop_last=False, shuffle=False)

if os.path.exists("submission.csv"):
  fo = open("submission.csv", mode="a", newline="")
  csv_writer = csv.writer(fo)
else:
  fo = open("submission.csv", mode="a", newline="")
  csv_writer = csv.writer(fo)
  csv_writer.writerow(["filename", "label"])

model.eval()
for image, filenames in test_dl:
  image = image.to(device)

  pred = model(image)
  pred = torch.argmax(pred, dim=1)

  for ind, filename in enumerate(filenames):
    if (len(filename) < 3):
      print(filename)
    csv_writer.writerow([filename, str(pred[ind].item())])



x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])
x.shape=torch.Size([500, 1, 32, 32])


In [None]:
for filename, _ in test_data:
  if filename.startswith("task2") or filename.startswith("task3"):
    # print(filename)
    csv_writer.writerow([filename, 0])

fo.close()

In [None]:
!ls

captcha-hacker.zip  hw5.ipynb  input  submission.csv
