In [1]:
import os
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

# Prefer CUDA when it is available; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class DigitDataset(Dataset):
    """PNG image dataset with optional per-image box annotations.

    Every ``*.png`` file in ``data_dir`` is a sample. A sibling ``.txt`` file
    with the same stem, when present, holds one whitespace-separated row of
    numbers per box (presumably a class value followed by 4 coordinates,
    matching the 5-value rows used here -- confirm against the data).
    Targets are returned as a fixed ``(max_boxes, 5)`` float32 tensor so the
    default collate function can batch them.

    Args:
        data_dir: Directory containing the images and annotation files.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        max_boxes: Number of box rows in every returned target tensor;
            annotated boxes beyond this count are dropped.
    """

    def __init__(self, data_dir, transform=None, max_boxes=10):
        self.data_dir = data_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort naturally ('2.png' before
        # '10.png') so the sample order -- and any index derived from it,
        # such as prediction rows -- is reproducible.
        self.images = sorted(
            (img for img in os.listdir(data_dir) if img.endswith('.png')),
            key=self._natural_key,
        )
        self.max_boxes = max_boxes
        self.annotations = self._load_annotations()

    @staticmethod
    def _natural_key(name):
        """Sort key that compares runs of digits in ``name`` numerically."""
        # With a capturing group, re.split places the digit runs at odd
        # indices, so ints and strs always line up position by position.
        return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r'(\d+)', name))]

    def _load_annotations(self):
        """Map each image name that has a ``.txt`` file to its list of float rows."""
        annotations = {}
        for img in self.images:
            # Swap only the final extension ('a.png.png' -> 'a.png.txt'),
            # unlike str.replace which rewrites every '.png' in the name.
            stem, _ = os.path.splitext(img)
            annotation_path = os.path.join(self.data_dir, stem + '.txt')
            if os.path.exists(annotation_path):
                boxes = []
                with open(annotation_path, 'r') as f:
                    for line in f:
                        fields = line.split()
                        if fields:  # blank lines would become empty, ragged rows
                            boxes.append([float(x) for x in fields])
                annotations[img] = boxes
        return annotations

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, targets)`` for sample ``idx``.

        NOTE(review): box values are returned exactly as stored; they are not
        rescaled when ``transform`` resizes the image -- confirm the
        annotations are already in the coordinate system the model expects.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.data_dir, img_name)
        # convert() yields a new in-memory image, so the file can be closed
        # right away instead of leaking a handle until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, self._pad_targets(self.annotations.get(img_name, []))

    def _pad_targets(self, targets):
        """Return ``targets`` as a zero-padded float32 ``(max_boxes, 5)`` tensor.

        Rows past ``max_boxes`` are dropped so the shape is always fixed.
        """
        padded = np.zeros((self.max_boxes, 5))
        if len(targets) > 0:
            rows = np.asarray(targets, dtype=np.float64)[:self.max_boxes]
            padded[:len(rows), :5] = rows
        return torch.tensor(padded, dtype=torch.float32)

def _resize_to_tensor():
    """Build the preprocessing pipeline: resize to 32x32, then convert to a tensor."""
    return transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])


# Train and test share the same preprocessing (no augmentation), but each
# split gets its own pipeline instance.
train_transform = _resize_to_tensor()
test_transform = _resize_to_tensor()

class SmallResNet(nn.Module):
    """ResNet-18 adapted to 32x32 inputs that regresses a fixed set of boxes.

    ``forward`` reshapes the head output to ``(batch, num_boxes, 5)``: five
    values per box (class + 4 coordinates).

    Args:
        num_boxes: Number of box rows predicted per image.
        pretrained: Load ImageNet weights into the backbone (the default).
            Pass ``False`` to build a randomly initialised network without a
            download.
    """

    def __init__(self, num_boxes=10, pretrained=True):
        super().__init__()
        self.num_boxes = num_boxes
        self.model = self._build_backbone(pretrained)
        # Small-input stem: a 3x3/stride-1 conv and no max-pool preserve the
        # spatial resolution of 32x32 images. This conv is a fresh module, so
        # the pretrained conv1 weights are discarded.
        self.model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.model.maxpool = nn.Identity()
        # 5 predictions (class + 4 coordinates) per box.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_boxes * 5)

    @staticmethod
    def _build_backbone(pretrained):
        """Return a torchvision ResNet-18, using the ``weights=`` API when available."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return models.resnet18(pretrained=pretrained)
        # ``pretrained=`` is deprecated since torchvision 0.13; DEFAULT is the
        # same ImageNet-1k checkpoint the legacy flag loaded.
        return models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)

    def forward(self, x):
        """Map a batch of images to ``(batch, num_boxes, 5)`` box predictions."""
        return self.model(x).view(-1, self.num_boxes, 5)

# Instantiate the network and move its parameters to the selected device.
model = SmallResNet()
model = model.to(device)

# NOTE(review): plain MSE regresses the class id as a continuous value
# together with the coordinates, and also penalises predictions on the
# zero-padded box slots -- confirm this is the intended objective.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_model(model, criterion, optimizer, dataloader, num_epochs, device=None, log_every=10):
    """Train ``model`` in place and return the average loss of each epoch.

    Args:
        model: Module to optimise; it is switched to training mode.
        criterion: Loss callable ``criterion(outputs, targets)`` returning a
            batch-mean scalar tensor (weighted by batch size below).
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Yields ``(inputs, targets)`` batches and exposes
            ``.dataset`` with a length.
        num_epochs: Number of full passes over ``dataloader``.
        device: Where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameterless model).
        log_every: Print the current batch loss every this many batches
            (must be positive).

    Returns:
        List with one per-sample average loss per epoch (0.0 for an empty
        dataset).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, targets) in enumerate(dataloader):
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            # Weight by batch size so the epoch average is per sample even
            # when the final batch is smaller.
            running_loss += batch_loss * images.size(0)
            if i % log_every == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(dataloader)}], Loss: {batch_loss:.4f}')

        n_samples = len(dataloader.dataset)
        # Guard against an empty dataset instead of raising ZeroDivisionError.
        epoch_loss = running_loss / n_samples if n_samples else 0.0
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {epoch_loss:.4f}')
    return epoch_losses

# Training data: shuffled mini-batches of 32 images from HW6_train.
train_dataset = DigitDataset(
    data_dir='HW6_train',
    transform=train_transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

train_model(model, criterion, optimizer, train_loader, num_epochs=20)

def predict(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` and flatten its box predictions into rows.

    Each batch output is assumed to be ``(batch, num_boxes, >=3)`` with the
    class value in column 0 and x, y in columns 1 and 2 (e.g. SmallResNet's
    ``(batch, num_boxes, 5)`` output) -- confirm for other models.

    Args:
        model: Module producing per-box predictions; switched to eval mode.
        dataloader: Yields ``(inputs, _)`` batches; its iteration order
            defines the image index.
        device: Where inputs are moved. Defaults to the device of the model's
            first parameter (CPU for a parameterless model).

    Returns:
        List of ``[image_index, class, x, y]`` rows, one per predicted box
        slot. ``image_index`` is the sample's position in the loader's
        iteration order, so it is correct for any batch size.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    results = []
    sample_idx = 0  # running sample position across batches, not the batch number
    with torch.no_grad():
        for images, _ in dataloader:
            outputs = model(images.to(device)).cpu().numpy()
            for output in outputs:
                # NOTE(review): every box slot becomes a row, including slots
                # the model may intend as "no box" -- filter downstream if needed.
                for box in output:
                    # The class is regressed as a continuous value; round to
                    # the nearest integer instead of truncating toward zero.
                    cls = int(round(float(box[0])))
                    results.append([sample_idx, cls, box[1], box[2]])
                sample_idx += 1
    return results

# Run inference over the test images one at a time, in a fixed order.
test_dataset = DigitDataset(data_dir='HW6_test', transform=test_transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
predictions = predict(model, test_loader)

# Write one CSV row per predicted box slot, without the DataFrame index.
prediction_columns = ['IMAGE_index', 'class', 'x', 'y']
df = pd.DataFrame(data=predictions, columns=prediction_columns)
df.to_csv('predictions.csv', index=False)




Epoch [1/20], Batch [1/3000], Loss: 2.5314
Epoch [1/20], Batch [11/3000], Loss: 0.6009
Epoch [1/20], Batch [21/3000], Loss: 0.7004
Epoch [1/20], Batch [31/3000], Loss: 0.5076
Epoch [1/20], Batch [41/3000], Loss: 0.6256
Epoch [1/20], Batch [51/3000], Loss: 0.5802
Epoch [1/20], Batch [61/3000], Loss: 0.4676
Epoch [1/20], Batch [71/3000], Loss: 0.5023
Epoch [1/20], Batch [81/3000], Loss: 0.4553
Epoch [1/20], Batch [91/3000], Loss: 0.4810
Epoch [1/20], Batch [101/3000], Loss: 0.4509
Epoch [1/20], Batch [111/3000], Loss: 0.4664
Epoch [1/20], Batch [121/3000], Loss: 0.3765
Epoch [1/20], Batch [131/3000], Loss: 0.4829
Epoch [1/20], Batch [141/3000], Loss: 0.4884
Epoch [1/20], Batch [151/3000], Loss: 0.4624
Epoch [1/20], Batch [161/3000], Loss: 0.4929
Epoch [1/20], Batch [171/3000], Loss: 0.4286
Epoch [1/20], Batch [181/3000], Loss: 0.4778
Epoch [1/20], Batch [191/3000], Loss: 0.3923
Epoch [1/20], Batch [201/3000], Loss: 0.3833
Epoch [1/20], Batch [211/3000], Loss: 0.3962
Epoch [1/20], Batch [

Epoch [1/20], Batch [1811/3000], Loss: 0.3162
Epoch [1/20], Batch [1821/3000], Loss: 0.3647
Epoch [1/20], Batch [1831/3000], Loss: 0.3598
Epoch [1/20], Batch [1841/3000], Loss: 0.3783
Epoch [1/20], Batch [1851/3000], Loss: 0.3888
Epoch [1/20], Batch [1861/3000], Loss: 0.4614
Epoch [1/20], Batch [1871/3000], Loss: 0.5056
Epoch [1/20], Batch [1881/3000], Loss: 0.4434
Epoch [1/20], Batch [1891/3000], Loss: 0.3504
Epoch [1/20], Batch [1901/3000], Loss: 0.3005
Epoch [1/20], Batch [1911/3000], Loss: 0.3661
Epoch [1/20], Batch [1921/3000], Loss: 0.4163
Epoch [1/20], Batch [1931/3000], Loss: 0.3850
Epoch [1/20], Batch [1941/3000], Loss: 0.3052
Epoch [1/20], Batch [1951/3000], Loss: 0.3544
Epoch [1/20], Batch [1961/3000], Loss: 0.3837
Epoch [1/20], Batch [1971/3000], Loss: 0.3074
Epoch [1/20], Batch [1981/3000], Loss: 0.4841
Epoch [1/20], Batch [1991/3000], Loss: 0.4594
Epoch [1/20], Batch [2001/3000], Loss: 0.5451
Epoch [1/20], Batch [2011/3000], Loss: 0.2917
Epoch [1/20], Batch [2021/3000], L

Epoch [2/20], Batch [601/3000], Loss: 0.2710
Epoch [2/20], Batch [611/3000], Loss: 0.3955
Epoch [2/20], Batch [621/3000], Loss: 0.3125
Epoch [2/20], Batch [631/3000], Loss: 0.4318
Epoch [2/20], Batch [641/3000], Loss: 0.3624
Epoch [2/20], Batch [651/3000], Loss: 0.3541
Epoch [2/20], Batch [661/3000], Loss: 0.3367
Epoch [2/20], Batch [671/3000], Loss: 0.3454
Epoch [2/20], Batch [681/3000], Loss: 0.4420
Epoch [2/20], Batch [691/3000], Loss: 0.2866
Epoch [2/20], Batch [701/3000], Loss: 0.3780
Epoch [2/20], Batch [711/3000], Loss: 0.3660
Epoch [2/20], Batch [721/3000], Loss: 0.2926
Epoch [2/20], Batch [731/3000], Loss: 0.3897
Epoch [2/20], Batch [741/3000], Loss: 0.4113
Epoch [2/20], Batch [751/3000], Loss: 0.3562
Epoch [2/20], Batch [761/3000], Loss: 0.3137
Epoch [2/20], Batch [771/3000], Loss: 0.4076
Epoch [2/20], Batch [781/3000], Loss: 0.3085
Epoch [2/20], Batch [791/3000], Loss: 0.4363
Epoch [2/20], Batch [801/3000], Loss: 0.4279
Epoch [2/20], Batch [811/3000], Loss: 0.4068
Epoch [2/2

Epoch [2/20], Batch [2391/3000], Loss: 0.4320
Epoch [2/20], Batch [2401/3000], Loss: 0.5048
Epoch [2/20], Batch [2411/3000], Loss: 0.3352
Epoch [2/20], Batch [2421/3000], Loss: 0.3568
Epoch [2/20], Batch [2431/3000], Loss: 0.3401
Epoch [2/20], Batch [2441/3000], Loss: 0.3887
Epoch [2/20], Batch [2451/3000], Loss: 0.4477
Epoch [2/20], Batch [2461/3000], Loss: 0.2649
Epoch [2/20], Batch [2471/3000], Loss: 0.5126
Epoch [2/20], Batch [2481/3000], Loss: 0.3491
Epoch [2/20], Batch [2491/3000], Loss: 0.4435
Epoch [2/20], Batch [2501/3000], Loss: 0.4651
Epoch [2/20], Batch [2511/3000], Loss: 0.4538
Epoch [2/20], Batch [2521/3000], Loss: 0.3916
Epoch [2/20], Batch [2531/3000], Loss: 0.3560
Epoch [2/20], Batch [2541/3000], Loss: 0.5109
Epoch [2/20], Batch [2551/3000], Loss: 0.3896
Epoch [2/20], Batch [2561/3000], Loss: 0.3154
Epoch [2/20], Batch [2571/3000], Loss: 0.4468
Epoch [2/20], Batch [2581/3000], Loss: 0.4210
Epoch [2/20], Batch [2591/3000], Loss: 0.4107
Epoch [2/20], Batch [2601/3000], L

Epoch [3/20], Batch [1191/3000], Loss: 0.4971
Epoch [3/20], Batch [1201/3000], Loss: 0.3317
Epoch [3/20], Batch [1211/3000], Loss: 0.4721
Epoch [3/20], Batch [1221/3000], Loss: 0.3097
Epoch [3/20], Batch [1231/3000], Loss: 0.4756
Epoch [3/20], Batch [1241/3000], Loss: 0.3486
Epoch [3/20], Batch [1251/3000], Loss: 0.3677
Epoch [3/20], Batch [1261/3000], Loss: 0.4145
Epoch [3/20], Batch [1271/3000], Loss: 0.4088
Epoch [3/20], Batch [1281/3000], Loss: 0.3801
Epoch [3/20], Batch [1291/3000], Loss: 0.3518
Epoch [3/20], Batch [1301/3000], Loss: 0.3656
Epoch [3/20], Batch [1311/3000], Loss: 0.2667
Epoch [3/20], Batch [1321/3000], Loss: 0.4693
Epoch [3/20], Batch [1331/3000], Loss: 0.3675
Epoch [3/20], Batch [1341/3000], Loss: 0.4438
Epoch [3/20], Batch [1351/3000], Loss: 0.2976
Epoch [3/20], Batch [1361/3000], Loss: 0.3508
Epoch [3/20], Batch [1371/3000], Loss: 0.3420
Epoch [3/20], Batch [1381/3000], Loss: 0.3442
Epoch [3/20], Batch [1391/3000], Loss: 0.3698
Epoch [3/20], Batch [1401/3000], L

Epoch [3/20], Batch [2981/3000], Loss: 0.2778
Epoch [3/20], Batch [2991/3000], Loss: 0.3885
Epoch 3/20, Average Loss: 0.3813
Epoch [4/20], Batch [1/3000], Loss: 0.4125
Epoch [4/20], Batch [11/3000], Loss: 0.3629
Epoch [4/20], Batch [21/3000], Loss: 0.3299
Epoch [4/20], Batch [31/3000], Loss: 0.3278
Epoch [4/20], Batch [41/3000], Loss: 0.5077
Epoch [4/20], Batch [51/3000], Loss: 0.2683
Epoch [4/20], Batch [61/3000], Loss: 0.4228
Epoch [4/20], Batch [71/3000], Loss: 0.4337
Epoch [4/20], Batch [81/3000], Loss: 0.4009
Epoch [4/20], Batch [91/3000], Loss: 0.3974
Epoch [4/20], Batch [101/3000], Loss: 0.3459
Epoch [4/20], Batch [111/3000], Loss: 0.4616
Epoch [4/20], Batch [121/3000], Loss: 0.4142
Epoch [4/20], Batch [131/3000], Loss: 0.3207
Epoch [4/20], Batch [141/3000], Loss: 0.3573
Epoch [4/20], Batch [151/3000], Loss: 0.3642
Epoch [4/20], Batch [161/3000], Loss: 0.3679
Epoch [4/20], Batch [171/3000], Loss: 0.3794
Epoch [4/20], Batch [181/3000], Loss: 0.3873
Epoch [4/20], Batch [191/3000],

Epoch [4/20], Batch [1781/3000], Loss: 0.2956
Epoch [4/20], Batch [1791/3000], Loss: 0.4619
Epoch [4/20], Batch [1801/3000], Loss: 0.3269
Epoch [4/20], Batch [1811/3000], Loss: 0.2959
Epoch [4/20], Batch [1821/3000], Loss: 0.3569
Epoch [4/20], Batch [1831/3000], Loss: 0.3580
Epoch [4/20], Batch [1841/3000], Loss: 0.2727
Epoch [4/20], Batch [1851/3000], Loss: 0.3201
Epoch [4/20], Batch [1861/3000], Loss: 0.4483
Epoch [4/20], Batch [1871/3000], Loss: 0.3983
Epoch [4/20], Batch [1881/3000], Loss: 0.4168
Epoch [4/20], Batch [1891/3000], Loss: 0.2995
Epoch [4/20], Batch [1901/3000], Loss: 0.3274
Epoch [4/20], Batch [1911/3000], Loss: 0.3847
Epoch [4/20], Batch [1921/3000], Loss: 0.4018
Epoch [4/20], Batch [1931/3000], Loss: 0.3342
Epoch [4/20], Batch [1941/3000], Loss: 0.2735
Epoch [4/20], Batch [1951/3000], Loss: 0.3384
Epoch [4/20], Batch [1961/3000], Loss: 0.4490
Epoch [4/20], Batch [1971/3000], Loss: 0.3487
Epoch [4/20], Batch [1981/3000], Loss: 0.3458
Epoch [4/20], Batch [1991/3000], L

Epoch [5/20], Batch [571/3000], Loss: 0.3395
Epoch [5/20], Batch [581/3000], Loss: 0.3894
Epoch [5/20], Batch [591/3000], Loss: 0.3893
Epoch [5/20], Batch [601/3000], Loss: 0.4144
Epoch [5/20], Batch [611/3000], Loss: 0.3590
Epoch [5/20], Batch [621/3000], Loss: 0.3114
Epoch [5/20], Batch [631/3000], Loss: 0.4752
Epoch [5/20], Batch [641/3000], Loss: 0.3460
Epoch [5/20], Batch [651/3000], Loss: 0.2968
Epoch [5/20], Batch [661/3000], Loss: 0.3847
Epoch [5/20], Batch [671/3000], Loss: 0.3745
Epoch [5/20], Batch [681/3000], Loss: 0.3000
Epoch [5/20], Batch [691/3000], Loss: 0.4011
Epoch [5/20], Batch [701/3000], Loss: 0.3336
Epoch [5/20], Batch [711/3000], Loss: 0.3204
Epoch [5/20], Batch [721/3000], Loss: 0.4362
Epoch [5/20], Batch [731/3000], Loss: 0.3498
Epoch [5/20], Batch [741/3000], Loss: 0.3995
Epoch [5/20], Batch [751/3000], Loss: 0.3379
Epoch [5/20], Batch [761/3000], Loss: 0.3298
Epoch [5/20], Batch [771/3000], Loss: 0.4856
Epoch [5/20], Batch [781/3000], Loss: 0.3832
Epoch [5/2

Epoch [5/20], Batch [2371/3000], Loss: 0.4510
Epoch [5/20], Batch [2381/3000], Loss: 0.4636
Epoch [5/20], Batch [2391/3000], Loss: 0.3454
Epoch [5/20], Batch [2401/3000], Loss: 0.3630
Epoch [5/20], Batch [2411/3000], Loss: 0.2539
Epoch [5/20], Batch [2421/3000], Loss: 0.3387
Epoch [5/20], Batch [2431/3000], Loss: 0.3309
Epoch [5/20], Batch [2441/3000], Loss: 0.3892
Epoch [5/20], Batch [2451/3000], Loss: 0.3277
Epoch [5/20], Batch [2461/3000], Loss: 0.2600
Epoch [5/20], Batch [2471/3000], Loss: 0.3711
Epoch [5/20], Batch [2481/3000], Loss: 0.5187
Epoch [5/20], Batch [2491/3000], Loss: 0.3624
Epoch [5/20], Batch [2501/3000], Loss: 0.4738
Epoch [5/20], Batch [2511/3000], Loss: 0.3462
Epoch [5/20], Batch [2521/3000], Loss: 0.3197
Epoch [5/20], Batch [2531/3000], Loss: 0.4023
Epoch [5/20], Batch [2541/3000], Loss: 0.3546
Epoch [5/20], Batch [2551/3000], Loss: 0.3700
Epoch [5/20], Batch [2561/3000], Loss: 0.4017
Epoch [5/20], Batch [2571/3000], Loss: 0.3688
Epoch [5/20], Batch [2581/3000], L

Epoch [6/20], Batch [1171/3000], Loss: 0.4193
Epoch [6/20], Batch [1181/3000], Loss: 0.3565
Epoch [6/20], Batch [1191/3000], Loss: 0.3716
Epoch [6/20], Batch [1201/3000], Loss: 0.4043
Epoch [6/20], Batch [1211/3000], Loss: 0.4326
Epoch [6/20], Batch [1221/3000], Loss: 0.4487
Epoch [6/20], Batch [1231/3000], Loss: 0.2831
Epoch [6/20], Batch [1241/3000], Loss: 0.3354
Epoch [6/20], Batch [1251/3000], Loss: 0.3413
Epoch [6/20], Batch [1261/3000], Loss: 0.3102
Epoch [6/20], Batch [1271/3000], Loss: 0.4234
Epoch [6/20], Batch [1281/3000], Loss: 0.3743
Epoch [6/20], Batch [1291/3000], Loss: 0.3406
Epoch [6/20], Batch [1301/3000], Loss: 0.4032
Epoch [6/20], Batch [1311/3000], Loss: 0.4793
Epoch [6/20], Batch [1321/3000], Loss: 0.2602
Epoch [6/20], Batch [1331/3000], Loss: 0.3720
Epoch [6/20], Batch [1341/3000], Loss: 0.2658
Epoch [6/20], Batch [1351/3000], Loss: 0.3975
Epoch [6/20], Batch [1361/3000], Loss: 0.3315
Epoch [6/20], Batch [1371/3000], Loss: 0.4290
Epoch [6/20], Batch [1381/3000], L

Epoch [6/20], Batch [2961/3000], Loss: 0.2885
Epoch [6/20], Batch [2971/3000], Loss: 0.2949
Epoch [6/20], Batch [2981/3000], Loss: 0.3249
Epoch [6/20], Batch [2991/3000], Loss: 0.3579
Epoch 6/20, Average Loss: 0.3695
Epoch [7/20], Batch [1/3000], Loss: 0.3901
Epoch [7/20], Batch [11/3000], Loss: 0.4306
Epoch [7/20], Batch [21/3000], Loss: 0.3390
Epoch [7/20], Batch [31/3000], Loss: 0.4096
Epoch [7/20], Batch [41/3000], Loss: 0.3643
Epoch [7/20], Batch [51/3000], Loss: 0.4400
Epoch [7/20], Batch [61/3000], Loss: 0.3621
Epoch [7/20], Batch [71/3000], Loss: 0.3337
Epoch [7/20], Batch [81/3000], Loss: 0.3627
Epoch [7/20], Batch [91/3000], Loss: 0.3664
Epoch [7/20], Batch [101/3000], Loss: 0.3854
Epoch [7/20], Batch [111/3000], Loss: 0.3252
Epoch [7/20], Batch [121/3000], Loss: 0.3982
Epoch [7/20], Batch [131/3000], Loss: 0.4332
Epoch [7/20], Batch [141/3000], Loss: 0.3720
Epoch [7/20], Batch [151/3000], Loss: 0.2776
Epoch [7/20], Batch [161/3000], Loss: 0.3924
Epoch [7/20], Batch [171/3000

Epoch [7/20], Batch [1761/3000], Loss: 0.3706
Epoch [7/20], Batch [1771/3000], Loss: 0.3666
Epoch [7/20], Batch [1781/3000], Loss: 0.3619
Epoch [7/20], Batch [1791/3000], Loss: 0.2909
Epoch [7/20], Batch [1801/3000], Loss: 0.3405
Epoch [7/20], Batch [1811/3000], Loss: 0.4388
Epoch [7/20], Batch [1821/3000], Loss: 0.3780
Epoch [7/20], Batch [1831/3000], Loss: 0.3352
Epoch [7/20], Batch [1841/3000], Loss: 0.3730
Epoch [7/20], Batch [1851/3000], Loss: 0.3897
Epoch [7/20], Batch [1861/3000], Loss: 0.2996
Epoch [7/20], Batch [1871/3000], Loss: 0.4716
Epoch [7/20], Batch [1881/3000], Loss: 0.3627
Epoch [7/20], Batch [1891/3000], Loss: 0.4257
Epoch [7/20], Batch [1901/3000], Loss: 0.3382
Epoch [7/20], Batch [1911/3000], Loss: 0.3699
Epoch [7/20], Batch [1921/3000], Loss: 0.4737
Epoch [7/20], Batch [1931/3000], Loss: 0.4578
Epoch [7/20], Batch [1941/3000], Loss: 0.2475
Epoch [7/20], Batch [1951/3000], Loss: 0.4193
Epoch [7/20], Batch [1961/3000], Loss: 0.3933
Epoch [7/20], Batch [1971/3000], L

Epoch [8/20], Batch [551/3000], Loss: 0.3724
Epoch [8/20], Batch [561/3000], Loss: 0.4422
Epoch [8/20], Batch [571/3000], Loss: 0.3159
Epoch [8/20], Batch [581/3000], Loss: 0.4056
Epoch [8/20], Batch [591/3000], Loss: 0.2971
Epoch [8/20], Batch [601/3000], Loss: 0.4337
Epoch [8/20], Batch [611/3000], Loss: 0.4474
Epoch [8/20], Batch [621/3000], Loss: 0.3533
Epoch [8/20], Batch [631/3000], Loss: 0.3666
Epoch [8/20], Batch [641/3000], Loss: 0.2581
Epoch [8/20], Batch [651/3000], Loss: 0.3026
Epoch [8/20], Batch [661/3000], Loss: 0.4533
Epoch [8/20], Batch [671/3000], Loss: 0.3781
Epoch [8/20], Batch [681/3000], Loss: 0.3897
Epoch [8/20], Batch [691/3000], Loss: 0.4371
Epoch [8/20], Batch [701/3000], Loss: 0.3914
Epoch [8/20], Batch [711/3000], Loss: 0.2834
Epoch [8/20], Batch [721/3000], Loss: 0.4440
Epoch [8/20], Batch [731/3000], Loss: 0.3327
Epoch [8/20], Batch [741/3000], Loss: 0.3432
Epoch [8/20], Batch [751/3000], Loss: 0.3172
Epoch [8/20], Batch [761/3000], Loss: 0.3349
Epoch [8/2

Epoch [8/20], Batch [2351/3000], Loss: 0.2947
Epoch [8/20], Batch [2361/3000], Loss: 0.4022
Epoch [8/20], Batch [2371/3000], Loss: 0.3746
Epoch [8/20], Batch [2381/3000], Loss: 0.4454
Epoch [8/20], Batch [2391/3000], Loss: 0.3678
Epoch [8/20], Batch [2401/3000], Loss: 0.3095
Epoch [8/20], Batch [2411/3000], Loss: 0.3895
Epoch [8/20], Batch [2421/3000], Loss: 0.4613
Epoch [8/20], Batch [2431/3000], Loss: 0.3187
Epoch [8/20], Batch [2441/3000], Loss: 0.3737
Epoch [8/20], Batch [2451/3000], Loss: 0.4151
Epoch [8/20], Batch [2461/3000], Loss: 0.3763
Epoch [8/20], Batch [2471/3000], Loss: 0.3931
Epoch [8/20], Batch [2481/3000], Loss: 0.3062
Epoch [8/20], Batch [2491/3000], Loss: 0.2648
Epoch [8/20], Batch [2501/3000], Loss: 0.2552
Epoch [8/20], Batch [2511/3000], Loss: 0.2822
Epoch [8/20], Batch [2521/3000], Loss: 0.3500
Epoch [8/20], Batch [2531/3000], Loss: 0.3760
Epoch [8/20], Batch [2541/3000], Loss: 0.4460
Epoch [8/20], Batch [2551/3000], Loss: 0.3197
Epoch [8/20], Batch [2561/3000], L

Epoch [9/20], Batch [1151/3000], Loss: 0.3179
Epoch [9/20], Batch [1161/3000], Loss: 0.4197
Epoch [9/20], Batch [1171/3000], Loss: 0.3905
Epoch [9/20], Batch [1181/3000], Loss: 0.3636
Epoch [9/20], Batch [1191/3000], Loss: 0.3914
Epoch [9/20], Batch [1201/3000], Loss: 0.3048
Epoch [9/20], Batch [1211/3000], Loss: 0.3104
Epoch [9/20], Batch [1221/3000], Loss: 0.3662
Epoch [9/20], Batch [1231/3000], Loss: 0.4444
Epoch [9/20], Batch [1241/3000], Loss: 0.3016
Epoch [9/20], Batch [1251/3000], Loss: 0.3729
Epoch [9/20], Batch [1261/3000], Loss: 0.4054
Epoch [9/20], Batch [1271/3000], Loss: 0.4025
Epoch [9/20], Batch [1281/3000], Loss: 0.3138
Epoch [9/20], Batch [1291/3000], Loss: 0.4899
Epoch [9/20], Batch [1301/3000], Loss: 0.5051
Epoch [9/20], Batch [1311/3000], Loss: 0.4161
Epoch [9/20], Batch [1321/3000], Loss: 0.3518
Epoch [9/20], Batch [1331/3000], Loss: 0.3710
Epoch [9/20], Batch [1341/3000], Loss: 0.4093
Epoch [9/20], Batch [1351/3000], Loss: 0.3493
Epoch [9/20], Batch [1361/3000], L

Epoch [9/20], Batch [2941/3000], Loss: 0.3580
Epoch [9/20], Batch [2951/3000], Loss: 0.2877
Epoch [9/20], Batch [2961/3000], Loss: 0.4801
Epoch [9/20], Batch [2971/3000], Loss: 0.3945
Epoch [9/20], Batch [2981/3000], Loss: 0.3857
Epoch [9/20], Batch [2991/3000], Loss: 0.3652
Epoch 9/20, Average Loss: 0.3612
Epoch [10/20], Batch [1/3000], Loss: 0.3196
Epoch [10/20], Batch [11/3000], Loss: 0.2579
Epoch [10/20], Batch [21/3000], Loss: 0.4048
Epoch [10/20], Batch [31/3000], Loss: 0.2122
Epoch [10/20], Batch [41/3000], Loss: 0.2941
Epoch [10/20], Batch [51/3000], Loss: 0.3522
Epoch [10/20], Batch [61/3000], Loss: 0.3348
Epoch [10/20], Batch [71/3000], Loss: 0.2543
Epoch [10/20], Batch [81/3000], Loss: 0.3702
Epoch [10/20], Batch [91/3000], Loss: 0.3719
Epoch [10/20], Batch [101/3000], Loss: 0.3448
Epoch [10/20], Batch [111/3000], Loss: 0.4108
Epoch [10/20], Batch [121/3000], Loss: 0.3935
Epoch [10/20], Batch [131/3000], Loss: 0.3214
Epoch [10/20], Batch [141/3000], Loss: 0.3883
Epoch [10/20

Epoch [10/20], Batch [1711/3000], Loss: 0.3878
Epoch [10/20], Batch [1721/3000], Loss: 0.3138
Epoch [10/20], Batch [1731/3000], Loss: 0.3776
Epoch [10/20], Batch [1741/3000], Loss: 0.3906
Epoch [10/20], Batch [1751/3000], Loss: 0.3395
Epoch [10/20], Batch [1761/3000], Loss: 0.2864
Epoch [10/20], Batch [1771/3000], Loss: 0.3901
Epoch [10/20], Batch [1781/3000], Loss: 0.4094
Epoch [10/20], Batch [1791/3000], Loss: 0.2596
Epoch [10/20], Batch [1801/3000], Loss: 0.4095
Epoch [10/20], Batch [1811/3000], Loss: 0.3351
Epoch [10/20], Batch [1821/3000], Loss: 0.3833
Epoch [10/20], Batch [1831/3000], Loss: 0.3051
Epoch [10/20], Batch [1841/3000], Loss: 0.3323
Epoch [10/20], Batch [1851/3000], Loss: 0.2932
Epoch [10/20], Batch [1861/3000], Loss: 0.3549
Epoch [10/20], Batch [1871/3000], Loss: 0.3496
Epoch [10/20], Batch [1881/3000], Loss: 0.3594
Epoch [10/20], Batch [1891/3000], Loss: 0.3735
Epoch [10/20], Batch [1901/3000], Loss: 0.3188
Epoch [10/20], Batch [1911/3000], Loss: 0.4375
Epoch [10/20]

Epoch [11/20], Batch [461/3000], Loss: 0.3153
Epoch [11/20], Batch [471/3000], Loss: 0.2900
Epoch [11/20], Batch [481/3000], Loss: 0.3742
Epoch [11/20], Batch [491/3000], Loss: 0.3153
Epoch [11/20], Batch [501/3000], Loss: 0.3738
Epoch [11/20], Batch [511/3000], Loss: 0.3400
Epoch [11/20], Batch [521/3000], Loss: 0.4335
Epoch [11/20], Batch [531/3000], Loss: 0.2625
Epoch [11/20], Batch [541/3000], Loss: 0.2863
Epoch [11/20], Batch [551/3000], Loss: 0.3231
Epoch [11/20], Batch [561/3000], Loss: 0.3029
Epoch [11/20], Batch [571/3000], Loss: 0.2975
Epoch [11/20], Batch [581/3000], Loss: 0.3090
Epoch [11/20], Batch [591/3000], Loss: 0.4281
Epoch [11/20], Batch [601/3000], Loss: 0.3638
Epoch [11/20], Batch [611/3000], Loss: 0.2833
Epoch [11/20], Batch [621/3000], Loss: 0.3392
Epoch [11/20], Batch [631/3000], Loss: 0.2700
Epoch [11/20], Batch [641/3000], Loss: 0.3363
Epoch [11/20], Batch [651/3000], Loss: 0.3423
Epoch [11/20], Batch [661/3000], Loss: 0.2658
Epoch [11/20], Batch [671/3000], L

Epoch [11/20], Batch [2221/3000], Loss: 0.3485
Epoch [11/20], Batch [2231/3000], Loss: 0.3658
Epoch [11/20], Batch [2241/3000], Loss: 0.2683
Epoch [11/20], Batch [2251/3000], Loss: 0.2710
Epoch [11/20], Batch [2261/3000], Loss: 0.3020
Epoch [11/20], Batch [2271/3000], Loss: 0.4265
Epoch [11/20], Batch [2281/3000], Loss: 0.3566
Epoch [11/20], Batch [2291/3000], Loss: 0.3224
Epoch [11/20], Batch [2301/3000], Loss: 0.3284
Epoch [11/20], Batch [2311/3000], Loss: 0.3255
Epoch [11/20], Batch [2321/3000], Loss: 0.3305
Epoch [11/20], Batch [2331/3000], Loss: 0.3799
Epoch [11/20], Batch [2341/3000], Loss: 0.4252
Epoch [11/20], Batch [2351/3000], Loss: 0.2644
Epoch [11/20], Batch [2361/3000], Loss: 0.4428
Epoch [11/20], Batch [2371/3000], Loss: 0.3662
Epoch [11/20], Batch [2381/3000], Loss: 0.2734
Epoch [11/20], Batch [2391/3000], Loss: 0.3878
Epoch [11/20], Batch [2401/3000], Loss: 0.3835
Epoch [11/20], Batch [2411/3000], Loss: 0.3632
Epoch [11/20], Batch [2421/3000], Loss: 0.3940
Epoch [11/20]

Epoch [12/20], Batch [981/3000], Loss: 0.3274
Epoch [12/20], Batch [991/3000], Loss: 0.2916
Epoch [12/20], Batch [1001/3000], Loss: 0.3302
Epoch [12/20], Batch [1011/3000], Loss: 0.3677
Epoch [12/20], Batch [1021/3000], Loss: 0.3892
Epoch [12/20], Batch [1031/3000], Loss: 0.3756
Epoch [12/20], Batch [1041/3000], Loss: 0.3026
Epoch [12/20], Batch [1051/3000], Loss: 0.2602
Epoch [12/20], Batch [1061/3000], Loss: 0.3406
Epoch [12/20], Batch [1071/3000], Loss: 0.2095
Epoch [12/20], Batch [1081/3000], Loss: 0.3063
Epoch [12/20], Batch [1091/3000], Loss: 0.3131
Epoch [12/20], Batch [1101/3000], Loss: 0.2959
Epoch [12/20], Batch [1111/3000], Loss: 0.2489
Epoch [12/20], Batch [1121/3000], Loss: 0.3653
Epoch [12/20], Batch [1131/3000], Loss: 0.4138
Epoch [12/20], Batch [1141/3000], Loss: 0.3357
Epoch [12/20], Batch [1151/3000], Loss: 0.3327
Epoch [12/20], Batch [1161/3000], Loss: 0.2788
Epoch [12/20], Batch [1171/3000], Loss: 0.3307
Epoch [12/20], Batch [1181/3000], Loss: 0.2782
Epoch [12/20], 

Epoch [12/20], Batch [2731/3000], Loss: 0.3252
Epoch [12/20], Batch [2741/3000], Loss: 0.3153
Epoch [12/20], Batch [2751/3000], Loss: 0.4735
Epoch [12/20], Batch [2761/3000], Loss: 0.3193
Epoch [12/20], Batch [2771/3000], Loss: 0.2896
Epoch [12/20], Batch [2781/3000], Loss: 0.3394
Epoch [12/20], Batch [2791/3000], Loss: 0.4107
Epoch [12/20], Batch [2801/3000], Loss: 0.3233
Epoch [12/20], Batch [2811/3000], Loss: 0.3038
Epoch [12/20], Batch [2821/3000], Loss: 0.3460
Epoch [12/20], Batch [2831/3000], Loss: 0.3132
Epoch [12/20], Batch [2841/3000], Loss: 0.3374
Epoch [12/20], Batch [2851/3000], Loss: 0.3675
Epoch [12/20], Batch [2861/3000], Loss: 0.3620
Epoch [12/20], Batch [2871/3000], Loss: 0.4061
Epoch [12/20], Batch [2881/3000], Loss: 0.3461
Epoch [12/20], Batch [2891/3000], Loss: 0.3344
Epoch [12/20], Batch [2901/3000], Loss: 0.3319
Epoch [12/20], Batch [2911/3000], Loss: 0.3730
Epoch [12/20], Batch [2921/3000], Loss: 0.2919
Epoch [12/20], Batch [2931/3000], Loss: 0.2781
Epoch [12/20]

Epoch [13/20], Batch [1491/3000], Loss: 0.2707
Epoch [13/20], Batch [1501/3000], Loss: 0.2135
Epoch [13/20], Batch [1511/3000], Loss: 0.2156
Epoch [13/20], Batch [1521/3000], Loss: 0.3198
Epoch [13/20], Batch [1531/3000], Loss: 0.2064
Epoch [13/20], Batch [1541/3000], Loss: 0.2577
Epoch [13/20], Batch [1551/3000], Loss: 0.3472
Epoch [13/20], Batch [1561/3000], Loss: 0.3061
Epoch [13/20], Batch [1571/3000], Loss: 0.3692
Epoch [13/20], Batch [1581/3000], Loss: 0.2672
Epoch [13/20], Batch [1591/3000], Loss: 0.3119
Epoch [13/20], Batch [1601/3000], Loss: 0.2481
Epoch [13/20], Batch [1611/3000], Loss: 0.1847
Epoch [13/20], Batch [1621/3000], Loss: 0.4007
Epoch [13/20], Batch [1631/3000], Loss: 0.2478
Epoch [13/20], Batch [1641/3000], Loss: 0.2360
Epoch [13/20], Batch [1651/3000], Loss: 0.3217
Epoch [13/20], Batch [1661/3000], Loss: 0.3731
Epoch [13/20], Batch [1671/3000], Loss: 0.2428
Epoch [13/20], Batch [1681/3000], Loss: 0.2622
Epoch [13/20], Batch [1691/3000], Loss: 0.3361
Epoch [13/20]

Epoch [14/20], Batch [241/3000], Loss: 0.2789
Epoch [14/20], Batch [251/3000], Loss: 0.2272
Epoch [14/20], Batch [261/3000], Loss: 0.2701
Epoch [14/20], Batch [271/3000], Loss: 0.1898
Epoch [14/20], Batch [281/3000], Loss: 0.3060
Epoch [14/20], Batch [291/3000], Loss: 0.3486
Epoch [14/20], Batch [301/3000], Loss: 0.2656
Epoch [14/20], Batch [311/3000], Loss: 0.1987
Epoch [14/20], Batch [321/3000], Loss: 0.2971
Epoch [14/20], Batch [331/3000], Loss: 0.3069
Epoch [14/20], Batch [341/3000], Loss: 0.2811
Epoch [14/20], Batch [351/3000], Loss: 0.2277
Epoch [14/20], Batch [361/3000], Loss: 0.2545
Epoch [14/20], Batch [371/3000], Loss: 0.2916
Epoch [14/20], Batch [381/3000], Loss: 0.3034
Epoch [14/20], Batch [391/3000], Loss: 0.2160
Epoch [14/20], Batch [401/3000], Loss: 0.2044
Epoch [14/20], Batch [411/3000], Loss: 0.2668
Epoch [14/20], Batch [421/3000], Loss: 0.2007
Epoch [14/20], Batch [431/3000], Loss: 0.2438
Epoch [14/20], Batch [441/3000], Loss: 0.2072
Epoch [14/20], Batch [451/3000], L

Epoch [14/20], Batch [2001/3000], Loss: 0.2568
Epoch [14/20], Batch [2011/3000], Loss: 0.2508
Epoch [14/20], Batch [2021/3000], Loss: 0.2840
Epoch [14/20], Batch [2031/3000], Loss: 0.1911
Epoch [14/20], Batch [2041/3000], Loss: 0.1805
Epoch [14/20], Batch [2051/3000], Loss: 0.2444
Epoch [14/20], Batch [2061/3000], Loss: 0.2752
Epoch [14/20], Batch [2071/3000], Loss: 0.2675
Epoch [14/20], Batch [2081/3000], Loss: 0.3282
Epoch [14/20], Batch [2091/3000], Loss: 0.3036
Epoch [14/20], Batch [2101/3000], Loss: 0.3082
Epoch [14/20], Batch [2111/3000], Loss: 0.2474
Epoch [14/20], Batch [2121/3000], Loss: 0.1889
Epoch [14/20], Batch [2131/3000], Loss: 0.3137
Epoch [14/20], Batch [2141/3000], Loss: 0.2123
Epoch [14/20], Batch [2151/3000], Loss: 0.2034
Epoch [14/20], Batch [2161/3000], Loss: 0.2079
Epoch [14/20], Batch [2171/3000], Loss: 0.3165
Epoch [14/20], Batch [2181/3000], Loss: 0.2678
Epoch [14/20], Batch [2191/3000], Loss: 0.3075
Epoch [14/20], Batch [2201/3000], Loss: 0.3289
Epoch [14/20]

Epoch [15/20], Batch [761/3000], Loss: 0.2457
Epoch [15/20], Batch [771/3000], Loss: 0.2462
Epoch [15/20], Batch [781/3000], Loss: 0.2208
Epoch [15/20], Batch [791/3000], Loss: 0.2015
Epoch [15/20], Batch [801/3000], Loss: 0.2081
Epoch [15/20], Batch [811/3000], Loss: 0.2472
Epoch [15/20], Batch [821/3000], Loss: 0.2192
Epoch [15/20], Batch [831/3000], Loss: 0.2780
Epoch [15/20], Batch [841/3000], Loss: 0.1637
Epoch [15/20], Batch [851/3000], Loss: 0.1932
Epoch [15/20], Batch [861/3000], Loss: 0.2220
Epoch [15/20], Batch [871/3000], Loss: 0.1854
Epoch [15/20], Batch [881/3000], Loss: 0.2850
Epoch [15/20], Batch [891/3000], Loss: 0.1934
Epoch [15/20], Batch [901/3000], Loss: 0.1830
Epoch [15/20], Batch [911/3000], Loss: 0.1897
Epoch [15/20], Batch [921/3000], Loss: 0.2148
Epoch [15/20], Batch [931/3000], Loss: 0.2680
Epoch [15/20], Batch [941/3000], Loss: 0.1951
Epoch [15/20], Batch [951/3000], Loss: 0.2418
Epoch [15/20], Batch [961/3000], Loss: 0.2395
Epoch [15/20], Batch [971/3000], L

Epoch [15/20], Batch [2511/3000], Loss: 0.2862
Epoch [15/20], Batch [2521/3000], Loss: 0.2046
Epoch [15/20], Batch [2531/3000], Loss: 0.2603
Epoch [15/20], Batch [2541/3000], Loss: 0.2601
Epoch [15/20], Batch [2551/3000], Loss: 0.1493
Epoch [15/20], Batch [2561/3000], Loss: 0.1933
Epoch [15/20], Batch [2571/3000], Loss: 0.3034
Epoch [15/20], Batch [2581/3000], Loss: 0.2187
Epoch [15/20], Batch [2591/3000], Loss: 0.2145
Epoch [15/20], Batch [2601/3000], Loss: 0.2384
Epoch [15/20], Batch [2611/3000], Loss: 0.2423
Epoch [15/20], Batch [2621/3000], Loss: 0.1707
Epoch [15/20], Batch [2631/3000], Loss: 0.2895
Epoch [15/20], Batch [2641/3000], Loss: 0.1528
Epoch [15/20], Batch [2651/3000], Loss: 0.1690
Epoch [15/20], Batch [2661/3000], Loss: 0.2935
Epoch [15/20], Batch [2671/3000], Loss: 0.2544
Epoch [15/20], Batch [2681/3000], Loss: 0.2142
Epoch [15/20], Batch [2691/3000], Loss: 0.2416
Epoch [15/20], Batch [2701/3000], Loss: 0.2223
Epoch [15/20], Batch [2711/3000], Loss: 0.2302
Epoch [15/20]

Epoch [16/20], Batch [1271/3000], Loss: 0.2107
Epoch [16/20], Batch [1281/3000], Loss: 0.1984
Epoch [16/20], Batch [1291/3000], Loss: 0.2124
Epoch [16/20], Batch [1301/3000], Loss: 0.1951
Epoch [16/20], Batch [1311/3000], Loss: 0.1652
Epoch [16/20], Batch [1321/3000], Loss: 0.1937
Epoch [16/20], Batch [1331/3000], Loss: 0.2065
Epoch [16/20], Batch [1341/3000], Loss: 0.1475
Epoch [16/20], Batch [1351/3000], Loss: 0.1644
Epoch [16/20], Batch [1361/3000], Loss: 0.1488
Epoch [16/20], Batch [1371/3000], Loss: 0.1551
Epoch [16/20], Batch [1381/3000], Loss: 0.1499
Epoch [16/20], Batch [1391/3000], Loss: 0.1347
Epoch [16/20], Batch [1401/3000], Loss: 0.1788
Epoch [16/20], Batch [1411/3000], Loss: 0.1475
Epoch [16/20], Batch [1421/3000], Loss: 0.1518
Epoch [16/20], Batch [1431/3000], Loss: 0.1846
Epoch [16/20], Batch [1441/3000], Loss: 0.1899
Epoch [16/20], Batch [1451/3000], Loss: 0.2255
Epoch [16/20], Batch [1461/3000], Loss: 0.1765
Epoch [16/20], Batch [1471/3000], Loss: 0.1781
Epoch [16/20]

Epoch [17/20], Batch [11/3000], Loss: 0.2047
Epoch [17/20], Batch [21/3000], Loss: 0.1183
Epoch [17/20], Batch [31/3000], Loss: 0.0936
Epoch [17/20], Batch [41/3000], Loss: 0.1587
Epoch [17/20], Batch [51/3000], Loss: 0.1332
Epoch [17/20], Batch [61/3000], Loss: 0.1664
Epoch [17/20], Batch [71/3000], Loss: 0.1472
Epoch [17/20], Batch [81/3000], Loss: 0.1080
Epoch [17/20], Batch [91/3000], Loss: 0.1218
Epoch [17/20], Batch [101/3000], Loss: 0.1157
Epoch [17/20], Batch [111/3000], Loss: 0.1209
Epoch [17/20], Batch [121/3000], Loss: 0.1827
Epoch [17/20], Batch [131/3000], Loss: 0.1822
Epoch [17/20], Batch [141/3000], Loss: 0.1919
Epoch [17/20], Batch [151/3000], Loss: 0.1325
Epoch [17/20], Batch [161/3000], Loss: 0.1591
Epoch [17/20], Batch [171/3000], Loss: 0.1570
Epoch [17/20], Batch [181/3000], Loss: 0.1426
Epoch [17/20], Batch [191/3000], Loss: 0.1742
Epoch [17/20], Batch [201/3000], Loss: 0.1368
Epoch [17/20], Batch [211/3000], Loss: 0.1797
Epoch [17/20], Batch [221/3000], Loss: 0.12

Epoch [17/20], Batch [1781/3000], Loss: 0.1342
Epoch [17/20], Batch [1791/3000], Loss: 0.2267
Epoch [17/20], Batch [1801/3000], Loss: 0.2260
Epoch [17/20], Batch [1811/3000], Loss: 0.1449
Epoch [17/20], Batch [1821/3000], Loss: 0.1221
Epoch [17/20], Batch [1831/3000], Loss: 0.1399
Epoch [17/20], Batch [1841/3000], Loss: 0.1330
Epoch [17/20], Batch [1851/3000], Loss: 0.1492
Epoch [17/20], Batch [1861/3000], Loss: 0.1322
Epoch [17/20], Batch [1871/3000], Loss: 0.1530
Epoch [17/20], Batch [1881/3000], Loss: 0.1593
Epoch [17/20], Batch [1891/3000], Loss: 0.1552
Epoch [17/20], Batch [1901/3000], Loss: 0.1544
Epoch [17/20], Batch [1911/3000], Loss: 0.1656
Epoch [17/20], Batch [1921/3000], Loss: 0.1471
Epoch [17/20], Batch [1931/3000], Loss: 0.1561
Epoch [17/20], Batch [1941/3000], Loss: 0.1582
Epoch [17/20], Batch [1951/3000], Loss: 0.1550
Epoch [17/20], Batch [1961/3000], Loss: 0.1210
Epoch [17/20], Batch [1971/3000], Loss: 0.1151
Epoch [17/20], Batch [1981/3000], Loss: 0.1345
Epoch [17/20]

Epoch [18/20], Batch [531/3000], Loss: 0.1422
Epoch [18/20], Batch [541/3000], Loss: 0.1307
Epoch [18/20], Batch [551/3000], Loss: 0.1169
Epoch [18/20], Batch [561/3000], Loss: 0.1460
Epoch [18/20], Batch [571/3000], Loss: 0.1315
Epoch [18/20], Batch [581/3000], Loss: 0.1692
Epoch [18/20], Batch [591/3000], Loss: 0.1519
Epoch [18/20], Batch [601/3000], Loss: 0.0929
Epoch [18/20], Batch [611/3000], Loss: 0.0958
Epoch [18/20], Batch [621/3000], Loss: 0.1294
Epoch [18/20], Batch [631/3000], Loss: 0.0934
Epoch [18/20], Batch [641/3000], Loss: 0.1335
Epoch [18/20], Batch [651/3000], Loss: 0.1607
Epoch [18/20], Batch [661/3000], Loss: 0.1040
Epoch [18/20], Batch [671/3000], Loss: 0.0669
Epoch [18/20], Batch [681/3000], Loss: 0.0955
Epoch [18/20], Batch [691/3000], Loss: 0.1263
Epoch [18/20], Batch [701/3000], Loss: 0.1116
Epoch [18/20], Batch [711/3000], Loss: 0.1443
Epoch [18/20], Batch [721/3000], Loss: 0.1206
Epoch [18/20], Batch [731/3000], Loss: 0.0900
Epoch [18/20], Batch [741/3000], L

Epoch [18/20], Batch [2291/3000], Loss: 0.1574
Epoch [18/20], Batch [2301/3000], Loss: 0.1644
Epoch [18/20], Batch [2311/3000], Loss: 0.1701
Epoch [18/20], Batch [2321/3000], Loss: 0.0747
Epoch [18/20], Batch [2331/3000], Loss: 0.1663
Epoch [18/20], Batch [2341/3000], Loss: 0.1507
Epoch [18/20], Batch [2351/3000], Loss: 0.1259
Epoch [18/20], Batch [2361/3000], Loss: 0.1326
Epoch [18/20], Batch [2371/3000], Loss: 0.1599
Epoch [18/20], Batch [2381/3000], Loss: 0.1368
Epoch [18/20], Batch [2391/3000], Loss: 0.1251
Epoch [18/20], Batch [2401/3000], Loss: 0.1141
Epoch [18/20], Batch [2411/3000], Loss: 0.1577
Epoch [18/20], Batch [2421/3000], Loss: 0.1269
Epoch [18/20], Batch [2431/3000], Loss: 0.1169
Epoch [18/20], Batch [2441/3000], Loss: 0.1165
Epoch [18/20], Batch [2451/3000], Loss: 0.1342
Epoch [18/20], Batch [2461/3000], Loss: 0.1389
Epoch [18/20], Batch [2471/3000], Loss: 0.1676
Epoch [18/20], Batch [2481/3000], Loss: 0.1207
Epoch [18/20], Batch [2491/3000], Loss: 0.1171
Epoch [18/20]

Epoch [19/20], Batch [1051/3000], Loss: 0.0960
Epoch [19/20], Batch [1061/3000], Loss: 0.0890
Epoch [19/20], Batch [1071/3000], Loss: 0.0770
Epoch [19/20], Batch [1081/3000], Loss: 0.1403
Epoch [19/20], Batch [1091/3000], Loss: 0.0861
Epoch [19/20], Batch [1101/3000], Loss: 0.0775
Epoch [19/20], Batch [1111/3000], Loss: 0.0931
Epoch [19/20], Batch [1121/3000], Loss: 0.0726
Epoch [19/20], Batch [1131/3000], Loss: 0.0967
Epoch [19/20], Batch [1141/3000], Loss: 0.0805
Epoch [19/20], Batch [1151/3000], Loss: 0.1032
Epoch [19/20], Batch [1161/3000], Loss: 0.0965
Epoch [19/20], Batch [1171/3000], Loss: 0.1082
Epoch [19/20], Batch [1181/3000], Loss: 0.1220
Epoch [19/20], Batch [1191/3000], Loss: 0.1210
Epoch [19/20], Batch [1201/3000], Loss: 0.1133
Epoch [19/20], Batch [1211/3000], Loss: 0.0997
Epoch [19/20], Batch [1221/3000], Loss: 0.1019
Epoch [19/20], Batch [1231/3000], Loss: 0.1340
Epoch [19/20], Batch [1241/3000], Loss: 0.1519
Epoch [19/20], Batch [1251/3000], Loss: 0.0967
Epoch [19/20]

Epoch [19/20], Batch [2801/3000], Loss: 0.1650
Epoch [19/20], Batch [2811/3000], Loss: 0.1344
Epoch [19/20], Batch [2821/3000], Loss: 0.1041
Epoch [19/20], Batch [2831/3000], Loss: 0.1119
Epoch [19/20], Batch [2841/3000], Loss: 0.0921
Epoch [19/20], Batch [2851/3000], Loss: 0.0866
Epoch [19/20], Batch [2861/3000], Loss: 0.1428
Epoch [19/20], Batch [2871/3000], Loss: 0.1234
Epoch [19/20], Batch [2881/3000], Loss: 0.1047
Epoch [19/20], Batch [2891/3000], Loss: 0.0844
Epoch [19/20], Batch [2901/3000], Loss: 0.1166
Epoch [19/20], Batch [2911/3000], Loss: 0.1417
Epoch [19/20], Batch [2921/3000], Loss: 0.1154
Epoch [19/20], Batch [2931/3000], Loss: 0.0950
Epoch [19/20], Batch [2941/3000], Loss: 0.1302
Epoch [19/20], Batch [2951/3000], Loss: 0.1020
Epoch [19/20], Batch [2961/3000], Loss: 0.1766
Epoch [19/20], Batch [2971/3000], Loss: 0.1189
Epoch [19/20], Batch [2981/3000], Loss: 0.1146
Epoch [19/20], Batch [2991/3000], Loss: 0.1074
Epoch 19/20, Average Loss: 0.1071
Epoch [20/20], Batch [1/30

Epoch [20/20], Batch [1561/3000], Loss: 0.0842
Epoch [20/20], Batch [1571/3000], Loss: 0.0822
Epoch [20/20], Batch [1581/3000], Loss: 0.0981
Epoch [20/20], Batch [1591/3000], Loss: 0.0789
Epoch [20/20], Batch [1601/3000], Loss: 0.0886
Epoch [20/20], Batch [1611/3000], Loss: 0.1478
Epoch [20/20], Batch [1621/3000], Loss: 0.0859
Epoch [20/20], Batch [1631/3000], Loss: 0.0832
Epoch [20/20], Batch [1641/3000], Loss: 0.0750
Epoch [20/20], Batch [1651/3000], Loss: 0.0781
Epoch [20/20], Batch [1661/3000], Loss: 0.0932
Epoch [20/20], Batch [1671/3000], Loss: 0.0905
Epoch [20/20], Batch [1681/3000], Loss: 0.0773
Epoch [20/20], Batch [1691/3000], Loss: 0.0799
Epoch [20/20], Batch [1701/3000], Loss: 0.1154
Epoch [20/20], Batch [1711/3000], Loss: 0.0662
Epoch [20/20], Batch [1721/3000], Loss: 0.1102
Epoch [20/20], Batch [1731/3000], Loss: 0.0943
Epoch [20/20], Batch [1741/3000], Loss: 0.0998
Epoch [20/20], Batch [1751/3000], Loss: 0.1170
Epoch [20/20], Batch [1761/3000], Loss: 0.0642
Epoch [20/20]

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class DigitDataset(Dataset):
    """Dataset of ``*.png`` images with optional per-image box annotations.

    Every ``*.png`` file in ``data_dir`` is one sample. A sibling text file
    with the same stem and a ``.txt`` extension, if present, holds one
    whitespace-separated row of floats per box (expected: class id followed
    by four box values). Targets are returned as a zero-padded
    ``(max_boxes, 5)`` float32 tensor so samples can be batched by the
    default collate function.
    """

    def __init__(self, data_dir, transform=None, max_boxes=10):
        self.data_dir = data_dir
        self.transform = transform
        # Sorted so sample indices (and any index-based prediction output) do
        # not depend on the arbitrary order returned by os.listdir.
        self.images = sorted(img for img in os.listdir(data_dir) if img.endswith('.png'))
        self.max_boxes = max_boxes
        self.annotations = self._load_annotations()

    def _load_annotations(self):
        """Read every existing annotation file.

        Returns:
            ``{image_name: [[float, ...], ...]}`` with one inner list per
            non-blank line. Images without an annotation file are absent.
        """
        annotations = {}
        for img in self.images:
            # splitext only swaps the final extension ('a.png.png' -> 'a.png.txt').
            stem = os.path.splitext(img)[0]
            annotation_path = os.path.join(self.data_dir, stem + '.txt')
            if not os.path.exists(annotation_path):
                continue
            with open(annotation_path, 'r') as f:
                # Blank lines are skipped: they would otherwise become empty
                # rows and break padding in _pad_targets.
                annotations[img] = [
                    [float(x) for x in line.split()] for line in f if line.strip()
                ]
        return annotations

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, padded_targets)`` for sample ``idx``."""
        img_name = self.images[idx]
        img_path = os.path.join(self.data_dir, img_name)
        # Close the underlying file eagerly; convert() returns an independent copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        padded_targets = self._pad_targets(self.annotations.get(img_name, []))
        return image, padded_targets

    def _pad_targets(self, targets):
        """Pack box rows into a zero-padded ``(max_boxes, 5)`` float32 tensor.

        Rows beyond ``max_boxes`` are dropped, values beyond the fifth column
        are ignored and short rows are zero-filled, so a malformed annotation
        file cannot crash batching.
        """
        padded = np.zeros((self.max_boxes, 5))
        for row_idx, row in enumerate(targets[: self.max_boxes]):
            values = list(row)[:5]
            padded[row_idx, : len(values)] = values
        return torch.tensor(padded, dtype=torch.float32)

def _resize_and_tensorize(size=(128, 128)):
    """Build a pipeline that resizes a PIL image to ``size`` and turns it into a tensor."""
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])


# Train and test use identical preprocessing (no augmentation), each with its
# own pipeline object.
train_transform = _resize_and_tensorize()
test_transform = _resize_and_tensorize()

class DigitDetector(nn.Module):
    """ResNet-18 backbone whose final layer regresses a fixed number of boxes.

    For each of ``num_boxes`` slots the head emits 5 numbers (class + 4
    coordinates); ``forward`` reshapes the flat head output to
    ``(batch, num_boxes, 5)``.
    """

    def __init__(self, num_boxes=10):
        super().__init__()
        self.num_boxes = num_boxes
        backbone = models.resnet18(pretrained=True)
        # Replace the ImageNet classifier with a head producing 5 values per box slot.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_boxes * 5)
        self.model = backbone

    def forward(self, x):
        flat = self.model(x)
        return flat.view(-1, self.num_boxes, 5)

# Detector on the selected device, trained with plain MSE over every box slot
# and Adam at a 1e-3 learning rate.
model = DigitDetector().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_model(model, criterion, optimizer, dataloader, num_epochs, log_every=10, device=None):
    """Train ``model`` in place for ``num_epochs`` passes over ``dataloader``.

    Args:
        model: network to optimise; it is switched to training mode.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        dataloader: yields ``(inputs, targets)`` batches.
        num_epochs: number of full passes over the data.
        log_every: print the batch loss every this many batches (0 disables).
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        List with the per-sample average loss of each epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    num_samples = len(dataloader.dataset)
    history = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, targets) in enumerate(dataloader):
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch figure is a per-sample average
            # even when the last batch is smaller.
            running_loss += loss.item() * images.size(0)
            if log_every and i % log_every == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}')

        # An empty dataset would otherwise raise ZeroDivisionError here.
        epoch_loss = running_loss / num_samples if num_samples else 0.0
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {epoch_loss:.4f}')
        history.append(epoch_loss)
    return history

# Build the training set from every .png under HW6_train and fit for one epoch.
train_dataset = DigitDataset(data_dir='HW6_train', transform=train_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
train_model(model, criterion, optimizer, train_loader, num_epochs=1)

def predict(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` and flatten its box outputs into rows.

    Each box slot of each image yields one row
    ``[image_index, class, x, y]``:

    - ``image_index`` is the sample's position in the dataloader's iteration
      order. It is counted per image, so it is correct for any batch size.
    - ``class`` is box value 0 rounded to the nearest integer.
    - ``x`` / ``y`` are box values 1 and 2.

    Args:
        model: network returning a ``(batch, num_boxes, >=3)`` tensor.
        dataloader: yields ``(inputs, _)`` batches; the second item is ignored.
        device: where inputs are moved. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        List of ``[int, int, float, float]`` rows, ``num_boxes`` rows per image.
        NOTE(review): every slot is emitted, including ones the model
        predicts as empty padding; no confidence filtering is applied.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    results = []
    image_index = 0
    with torch.no_grad():
        for images, _ in dataloader:
            outputs = model(images.to(device)).cpu().numpy()
            for output in outputs:
                for box in output:
                    # The class is regressed as a float (MSE objective), so
                    # round rather than truncate: 1.97 decodes to 2, not 1.
                    cls = int(round(float(box[0])))
                    results.append([image_index, cls, float(box[1]), float(box[2])])
                image_index += 1
    return results

# Score the test split one image at a time in dataset order and write one CSV
# row per predicted box slot.
test_dataset = DigitDataset(data_dir='HW6_test', transform=test_transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
predictions = predict(model, test_loader)

df = pd.DataFrame(predictions, columns=['IMAGE_index', 'class', 'x', 'y'])
df.to_csv('predictions.csv', index=False)




KeyboardInterrupt: 

In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# import torchvision.transforms as transforms
# import torchvision.datasets as datasets
# import torchvision.models as models
# from torch.utils.data import DataLoader, Dataset
# from PIL import Image
# import os
# import pandas as pd
# import numpy as np
# from sklearn.model_selection import train_test_split

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# class DigitDataset(Dataset):
#     def __init__(self, data_dir, transform=None, max_boxes=10):
#         self.data_dir = data_dir
#         self.transform = transform
#         self.images = [img for img in os.listdir(data_dir) if img.endswith('.png')]
#         self.max_boxes = max_boxes
#         self.annotations = self._load_annotations()

#     def _load_annotations(self):
#         annotations = {}
#         for img in self.images:
#             annotation_path = os.path.join(self.data_dir, img.replace('.png', '.txt'))
#             if os.path.exists(annotation_path):
#                 with open(annotation_path, 'r') as f:
#                     boxes = []
#                     for line in f:
#                         boxes.append([float(x) for x in line.strip().split()])
#                     annotations[img] = boxes
#         return annotations

#     def __len__(self):
#         return len(self.images)

#     def __getitem__(self, idx):
#         img_name = self.images[idx]
#         img_path = os.path.join(self.data_dir, img_name)
#         image = Image.open(img_path).convert("RGB")

#         if self.transform:
#             image = self.transform(image)

#         targets = self.annotations.get(img_name, [])
#         # Pad targets to fixed size
#         padded_targets = np.zeros((self.max_boxes, 5))
#         if len(targets) > 0:
#             targets = np.array(targets)
#             padded_targets[:len(targets), :5] = targets

#         padded_targets = torch.tensor(padded_targets, dtype=torch.float32)
#         return image, padded_targets

# train_transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
# ])

# test_transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
# ])


# class DigitDetector(nn.Module):
#     def __init__(self, num_boxes=10):
#         super(DigitDetector, self).__init__()
#         self.num_boxes = num_boxes
#         self.model = models.resnet18(pretrained=True)
#         self.model.fc = nn.Linear(self.model.fc.in_features, num_boxes * 5)  # 5 predictions (class + 4 coordinates) per box

#     def forward(self, x):
#         return self.model(x).view(-1, self.num_boxes, 5)

# model = DigitDetector().to(device)


# criterion = nn.MSELoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)


# def train_model(model, criterion, optimizer, dataloader, num_epochs):
#     model.train()
#     for epoch in range(num_epochs):
#         running_loss = 0.0
#         for i, (images, targets) in enumerate(dataloader):
#             images = images.to(device)
#             targets = targets.to(device)  # No need to reshape here

#             optimizer.zero_grad()
#             outputs = model(images)
            
#             # Reshape targets and outputs to match the loss function expectations
#             loss = criterion(outputs, targets)
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item() * images.size(0)
#             if i % 10 == 0:  # Display loss every 10 batches
#                 print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}')

#         epoch_loss = running_loss / len(dataloader.dataset)
#         print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {epoch_loss:.4f}')

# train_dataset = DigitDataset(data_dir='HW6_train', transform=train_transform)
# # train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# train_model(model, criterion, optimizer, train_loader, num_epochs=1)###



# def predict(model, dataloader):
#     model.eval()
#     results = []
#     with torch.no_grad():
#         for idx, image in enumerate(dataloader):
#             image = image.to(device)
#             outputs = model(image)
#             outputs = outputs.cpu().numpy()
#             for output in outputs:
#                 cls = int(output[0])
#                 x = output[1]
#                 y = output[2]
#                 results.append([idx, cls, x, y])
#     return results

# test_dataset = DigitDataset(data_dir='HW6_test', transform=test_transform)
# test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# predictions = predict(model, test_loader)

# # Convert predictions to dataframe
# df = pd.DataFrame(predictions, columns=['IMAGE_index', 'class', 'x', 'y'])
# df.to_csv('predictions.csv', index=False)