In [1]:
import torch


In [7]:
import os
import cv2
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#import matplotlib.pyplot as plt


REBUILD_DATA = False

class DogsVSCats():
    IMG_SIZE = 50
    CATS ="PetImages/Cat"
    DOGS = "PetImages/Dog"
    LABELS = {CATS: 0, DOGS: 1}

    training_data = []
    catcount = 0
    dogcount = 0

    def make_training_data(self):
        for label in self.LABELS:
            print(label)
            for f in tqdm(os.listdir(label)):
                try:
                    path = os.path.join(label, f)
                    img = cv2.imread(path,cv2.IMREAD_GRAYSCALE)
                    img = cv2.resize(img,(self.IMG_SIZE,self.IMG_SIZE))
                    self.training_data.append([np.array(img),np.eye(2)[self.LABELS[label]]])

                    if label == self.CATS:
                        self.catcount += 1
                    elif label == self.DOGS:
                        self.dogcount += 1

                except Exception as e:
                    print(str(e))

        np.random.shuffle(self.training_data)
        np.save("training_data.npy",self.training_data)
        print("Cats: ", self.catcount,"\n")
        print("Dogs: ",self.dogcount, "\n")

dogsVsCats = DogsVSCats()
if REBUILD_DATA:
    dogsVsCats.make_training_data()

training_data = np.load("training_data.npy",allow_pickle=True)
print(len(training_data))
print(training_data[0])

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,32,5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)

        x = torch.randn(50,50).view(-1,1,50,50)
        self._to_linear = None
        self.convs(x)
        self.fc1 = nn.Linear(self._to_linear,512)
        self.fc2 = nn.Linear(512,2)

    def convs(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))

        print(x[0].shape)
        if self._to_linear is None:
            self._to_linear = x[0].shape[0]*x[0].shape[1]*x[0].shape[2]
        return x

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim = 1)

net = Net()

# show image for debug purposes
#img = plt.imshow(training_data[5][0],cmap="gray")
#plt.show()


optimizer = optim.Adam(net.parameters(),lr = 0.001)
loss_function = nn.MSELoss()

X = torch.Tensor([i[0] for i in training_data]).view(-1, 50, 50)
X = X/255.0
y = torch.Tensor([i[1] for i in training_data])

VAL_PCT = 0.1

val_size = int(len(X)*VAL_PCT)
print(val_size)

train_X = X[:-val_size]
train_Y = y[:-val_size]

test_X = X[-val_size:]
test_Y = y[-val_size:]

print(len(train_X),len(test_X))

BATCH_SIZE = 100
EPOCHS = 1

for epoch in range(EPOCHS):
    for i in tqdm(range(0,len(train_X),BATCH_SIZE)):
        #print(i, i + BATCH_SIZE)
        batch_x = train_X[i:i+BATCH_SIZE].view(-1,1,50,50)
        batch_y = train_Y[i:i+BATCH_SIZE]

        net.zero_grad()
        outputs = net(batch_x)
        loss = loss_function(outputs,batch_y)
        loss.backward()
        optimizer.step()
    print(f"Epoch: {epoch} Loss {loss}")


correct = 0
total = 0

with torch.no_grad():
    for i in tqdm(range(len(test_X))):
        real_class = torch.argmax(test_Y[i])
        net_out = net(test_X[i].view(-1,1,50,50))[0]
        predicted_class = torch.argmax(net_out)
        if predicted_class == real_class:
            correct += 1
        total += 1


print("Accuracy ", round(correct/total,3)*100)


24946
[array([[  9,  70,  72, ..., 104,  92, 159],
       [ 14,  50,  41, ..., 124,  97, 146],
       [  5,  36,  29, ..., 126,  96, 141],
       ...,
       [  9,  47,  96, ..., 146,  96, 142],
       [ 17,  65, 102, ..., 114,  87, 139],
       [ 11,  32,  38, ...,  42,  36,  65]], dtype=uint8)
 array([0., 1.])]
torch.Size([128, 2, 2])


  0%|▎                                                                                 | 1/225 [00:00<00:32,  6.86it/s]

2494
22452 2494
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


  2%|█▍                                                                                | 4/225 [00:00<00:26,  8.24it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

  3%|██▏                                                                               | 6/225 [00:00<00:24,  8.85it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

  3%|██▌                                                                               | 7/225 [00:00<00:23,  9.12it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

  4%|███▎                                                                              | 9/225 [00:00<00:22,  9.45it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

  5%|███▉                                                                             | 11/225 [00:01<00:21,  9.73it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

  6%|████▋                                                                            | 13/225 [00:01<00:21,  9.86it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

  7%|█████▊                                                                           | 16/225 [00:01<00:20, 10.04it/s]


torch.Size([128, 2, 2])


  8%|██████                                                                           | 17/225 [00:01<00:20,  9.98it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

  8%|██████▊                                                                          | 19/225 [00:01<00:20, 10.01it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

  9%|███████▌                                                                         | 21/225 [00:02<00:20, 10.13it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 10%|████████▎                                                                        | 23/225 [00:02<00:19, 10.38it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 11%|█████████                                                                        | 25/225 [00:02<00:19, 10.47it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 12%|█████████▋                                                                       | 27/225 [00:02<00:18, 10.51it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 13%|██████████▍                                                                      | 29/225 [00:02<00:18, 10.51it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 14%|███████████▏                                                                     | 31/225 [00:03<00:18, 10.42it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 15%|███████████▉                                                                     | 33/225 [00:03<00:18, 10.46it/s]


torch.Size([128, 2, 2])


 16%|████████████▌                                                                    | 35/225 [00:03<00:18, 10.39it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 16%|█████████████▎                                                                   | 37/225 [00:03<00:18, 10.23it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 17%|██████████████                                                                   | 39/225 [00:03<00:17, 10.36it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 18%|██████████████▊                                                                  | 41/225 [00:04<00:17, 10.29it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 19%|███████████████▍                                                                 | 43/225 [00:04<00:17, 10.27it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 20%|████████████████▏                                                                | 45/225 [00:04<00:17, 10.20it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 21%|████████████████▉                                                                | 47/225 [00:04<00:17, 10.19it/s]


torch.Size([128, 2, 2])


 22%|█████████████████▋                                                               | 49/225 [00:04<00:17, 10.31it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 23%|██████████████████▎                                                              | 51/225 [00:04<00:16, 10.40it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 24%|███████████████████                                                              | 53/225 [00:05<00:16, 10.55it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 24%|███████████████████▊                                                             | 55/225 [00:05<00:16, 10.42it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 25%|████████████████████▌                                                            | 57/225 [00:05<00:16, 10.47it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 26%|█████████████████████▏                                                           | 59/225 [00:05<00:15, 10.42it/s]


torch.Size([128, 2, 2])


 27%|█████████████████████▉                                                           | 61/225 [00:05<00:15, 10.38it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 28%|██████████████████████▋                                                          | 63/225 [00:06<00:15, 10.49it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 29%|███████████████████████▍                                                         | 65/225 [00:06<00:15, 10.37it/s]


torch.Size([128, 2, 2])


 30%|████████████████████████                                                         | 67/225 [00:06<00:15, 10.41it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 31%|████████████████████████▊                                                        | 69/225 [00:06<00:15, 10.18it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 32%|█████████████████████████▌                                                       | 71/225 [00:06<00:15, 10.24it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 32%|██████████████████████████▎                                                      | 73/225 [00:07<00:14, 10.22it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 33%|███████████████████████████                                                      | 75/225 [00:07<00:15,  9.76it/s]


torch.Size([128, 2, 2])


 34%|███████████████████████████▋                                                     | 77/225 [00:07<00:15,  9.80it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 35%|████████████████████████████▍                                                    | 79/225 [00:07<00:14,  9.94it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 36%|█████████████████████████████▏                                                   | 81/225 [00:07<00:14, 10.13it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 37%|█████████████████████████████▉                                                   | 83/225 [00:08<00:13, 10.19it/s]


torch.Size([128, 2, 2])


 38%|██████████████████████████████▌                                                  | 85/225 [00:08<00:14,  9.93it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 39%|███████████████████████████████▎                                                 | 87/225 [00:08<00:14,  9.84it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 40%|████████████████████████████████                                                 | 89/225 [00:08<00:13,  9.92it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 41%|█████████████████████████████████                                                | 92/225 [00:09<00:13,  9.87it/s]


torch.Size([128, 2, 2])


 42%|█████████████████████████████████▊                                               | 94/225 [00:09<00:13, 10.00it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 43%|██████████████████████████████████▌                                              | 96/225 [00:09<00:12, 10.12it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 44%|███████████████████████████████████▎                                             | 98/225 [00:09<00:12,  9.95it/s]


torch.Size([128, 2, 2])


 44%|███████████████████████████████████▌                                            | 100/225 [00:09<00:13,  9.45it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 45%|████████████████████████████████████▎                                           | 102/225 [00:10<00:12,  9.55it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 46%|████████████████████████████████████▌                                           | 103/225 [00:10<00:12,  9.63it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 47%|█████████████████████████████████████▎                                          | 105/225 [00:10<00:12,  9.76it/s]


torch.Size([128, 2, 2])


 48%|██████████████████████████████████████                                          | 107/225 [00:10<00:11,  9.91it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 48%|██████████████████████████████████████▊                                         | 109/225 [00:10<00:11,  9.80it/s]


torch.Size([128, 2, 2])


 49%|███████████████████████████████████████▍                                        | 111/225 [00:10<00:12,  9.29it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 50%|████████████████████████████████████████▏                                       | 113/225 [00:11<00:11,  9.50it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 52%|█████████████████████████████████████████▏                                      | 116/225 [00:11<00:11,  9.56it/s]


torch.Size([128, 2, 2])


 52%|█████████████████████████████████████████▌                                      | 117/225 [00:11<00:11,  9.64it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 53%|██████████████████████████████████████████▎                                     | 119/225 [00:11<00:10,  9.71it/s]


torch.Size([128, 2, 2])


 54%|███████████████████████████████████████████                                     | 121/225 [00:11<00:10,  9.82it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 55%|████████████████████████████████████████████                                    | 124/225 [00:12<00:10,  9.92it/s]


torch.Size([128, 2, 2])


 56%|████████████████████████████████████████████▊                                   | 126/225 [00:12<00:09, 10.00it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 57%|█████████████████████████████████████████████▌                                  | 128/225 [00:12<00:09,  9.99it/s]


torch.Size([128, 2, 2])


 58%|██████████████████████████████████████████████▏                                 | 130/225 [00:12<00:09,  9.76it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 59%|██████████████████████████████████████████████▉                                 | 132/225 [00:13<00:09,  9.81it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 60%|███████████████████████████████████████████████▋                                | 134/225 [00:13<00:09,  9.85it/s]


torch.Size([128, 2, 2])


 60%|████████████████████████████████████████████████▎                               | 136/225 [00:13<00:08, 10.05it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 61%|█████████████████████████████████████████████████                               | 138/225 [00:13<00:08, 10.13it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 62%|█████████████████████████████████████████████████▊                              | 140/225 [00:13<00:08, 10.16it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 63%|██████████████████████████████████████████████████▍                             | 142/225 [00:14<00:08, 10.18it/s]


torch.Size([128, 2, 2])


 64%|███████████████████████████████████████████████████▏                            | 144/225 [00:14<00:08, 10.03it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 65%|███████████████████████████████████████████████████▉                            | 146/225 [00:14<00:07,  9.97it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 66%|████████████████████████████████████████████████████▌                           | 148/225 [00:14<00:07, 10.11it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 67%|█████████████████████████████████████████████████████▎                          | 150/225 [00:14<00:07, 10.13it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 68%|██████████████████████████████████████████████████████                          | 152/225 [00:15<00:07, 10.11it/s]


torch.Size([128, 2, 2])


 68%|██████████████████████████████████████████████████████▊                         | 154/225 [00:15<00:07, 10.13it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 69%|███████████████████████████████████████████████████████▍                        | 156/225 [00:15<00:06, 10.09it/s]


torch.Size([128, 2, 2])


 70%|████████████████████████████████████████████████████████▏                       | 158/225 [00:15<00:06, 10.19it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 71%|████████████████████████████████████████████████████████▉                       | 160/225 [00:15<00:06, 10.22it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 72%|█████████████████████████████████████████████████████████▌                      | 162/225 [00:16<00:06, 10.27it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 73%|██████████████████████████████████████████████████████████▎                     | 164/225 [00:16<00:06, 10.13it/s]


torch.Size([128, 2, 2])


 74%|███████████████████████████████████████████████████████████                     | 166/225 [00:16<00:05, 10.22it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 75%|███████████████████████████████████████████████████████████▋                    | 168/225 [00:16<00:05, 10.18it/s]


torch.Size([128, 2, 2])


 76%|████████████████████████████████████████████████████████████▍                   | 170/225 [00:16<00:05, 10.13it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 76%|█████████████████████████████████████████████████████████████▏                  | 172/225 [00:17<00:05, 10.09it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 77%|█████████████████████████████████████████████████████████████▊                  | 174/225 [00:17<00:05, 10.11it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 78%|██████████████████████████████████████████████████████████████▌                 | 176/225 [00:17<00:04, 10.18it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 80%|███████████████████████████████████████████████████████████████▋                | 179/225 [00:17<00:04,  9.70it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 80%|████████████████████████████████████████████████████████████████▎               | 181/225 [00:17<00:04,  9.89it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 81%|█████████████████████████████████████████████████████████████████               | 183/225 [00:18<00:04, 10.10it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 82%|█████████████████████████████████████████████████████████████████▊              | 185/225 [00:18<00:03, 10.14it/s]


torch.Size([128, 2, 2])


 83%|██████████████████████████████████████████████████████████████████▍             | 187/225 [00:18<00:03, 10.10it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 84%|███████████████████████████████████████████████████████████████████▏            | 189/225 [00:18<00:03, 10.14it/s]


torch.Size([128, 2, 2])


 85%|███████████████████████████████████████████████████████████████████▉            | 191/225 [00:18<00:03, 10.17it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 86%|████████████████████████████████████████████████████████████████████▌           | 193/225 [00:19<00:03, 10.13it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 87%|█████████████████████████████████████████████████████████████████████▎          | 195/225 [00:19<00:02, 10.13it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 88%|██████████████████████████████████████████████████████████████████████          | 197/225 [00:19<00:02, 10.17it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 88%|██████████████████████████████████████████████████████████████████████▊         | 199/225 [00:19<00:02, 10.24it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 89%|███████████████████████████████████████████████████████████████████████▍        | 201/225 [00:19<00:02, 10.38it/s]


torch.Size([128, 2, 2])


 90%|████████████████████████████████████████████████████████████████████████▏       | 203/225 [00:20<00:02, 10.37it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 91%|████████████████████████████████████████████████████████████████████████▉       | 205/225 [00:20<00:01, 10.31it/s]


torch.Size([128, 2, 2])


 92%|█████████████████████████████████████████████████████████████████████████▌      | 207/225 [00:20<00:01, 10.07it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 93%|██████████████████████████████████████████████████████████████████████████▎     | 209/225 [00:20<00:01, 10.12it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 94%|███████████████████████████████████████████████████████████████████████████     | 211/225 [00:20<00:01, 10.29it/s]


torch.Size([128, 2, 2])


 95%|███████████████████████████████████████████████████████████████████████████▋    | 213/225 [00:21<00:01, 10.23it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 96%|████████████████████████████████████████████████████████████████████████████▍   | 215/225 [00:21<00:00, 10.15it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

 96%|█████████████████████████████████████████████████████████████████████████████▏  | 217/225 [00:21<00:00, 10.11it/s]


torch.Size([128, 2, 2])


 97%|█████████████████████████████████████████████████████████████████████████████▊  | 219/225 [00:21<00:00, 10.00it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 98%|██████████████████████████████████████████████████████████████████████████████▌ | 221/225 [00:21<00:00,  9.99it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 99%|███████████████████████████████████████████████████████████████████████████████▎| 223/225 [00:22<00:00, 10.00it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])

100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [00:22<00:00, 10.12it/s]
  3%|██▌                                                                            | 81/2494 [00:00<00:03, 804.12it/s]


tensor(0.2272, grad_fn=<MseLossBackward>)
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2

 10%|███████▋                                                                      | 247/2494 [00:00<00:02, 811.75it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 13%|██████████▎                                                                   | 330/2494 [00:00<00:02, 815.38it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 20%|███████████████▌                                                              | 496/2494 [00:00<00:02, 817.33it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 27%|████████████████████▋                                                         | 663/2494 [00:00<00:02, 819.53it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 33%|█████████████████████████▉                                                    | 828/2494 [00:01<00:02, 818.81it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 40%|███████████████████████████████▏                                              | 996/2494 [00:01<00:01, 826.44it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 47%|███████████████████████████████████▉                                         | 1162/2494 [00:01<00:01, 825.18it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 53%|█████████████████████████████████████████                                    | 1328/2494 [00:01<00:01, 825.40it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 57%|███████████████████████████████████████████▌                                 | 1412/2494 [00:01<00:01, 827.94it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 63%|████████████████████████████████████████████████▊                            | 1579/2494 [00:01<00:01, 827.15it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 70%|█████████████████████████████████████████████████████▉                       | 1745/2494 [00:02<00:00, 821.38it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 77%|██████████████████████████████████████████████████████████▉                  | 1910/2494 [00:02<00:00, 818.02it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 83%|████████████████████████████████████████████████████████████████             | 2076/2494 [00:02<00:00, 821.87it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 90%|█████████████████████████████████████████████████████████████████████▎       | 2243/2494 [00:02<00:00, 822.46it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

 97%|██████████████████████████████████████████████████████████████████████████▎  | 2408/2494 [00:02<00:00, 816.83it/s]


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128

100%|█████████████████████████████████████████████████████████████████████████████| 2494/2494 [00:03<00:00, 819.89it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,




In [8]:
torch.cuda.is_available()

True

In [9]:
device = torch.device("cuda:0")

In [10]:
device

device(type='cuda', index=0)

In [11]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("runnign on the GPU")
else:
    print("running on the CPU")
    device = torch.device("cpu")
    

runnign on the GPU


In [12]:
net.to(device) 

Net(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=512, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=2, bias=True)
)

In [13]:
for epoch in range(EPOCHS):
    for i in tqdm(range(0,len(train_X),BATCH_SIZE)):
        #print(i, i + BATCH_SIZE)
        batch_x = train_X[i:i+BATCH_SIZE].view(-1,1,50,50)
        batch_y = train_Y[i:i+BATCH_SIZE]

        net.zero_grad()
        outputs = net(batch_x)
        loss = loss_function(outputs,batch_y)
        loss.backward()
        optimizer.step()
    print(f"Epoch: {epoch} Loss {loss}")
        

  0%|                                                                                                                                                                                                               | 0/225 [00:00<?, ?it/s]


RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

In [15]:
x = torch.ones(5,5)
print(x.dtype)

torch.float32
