In [1]:
from tqdm import tqdm
from model import SimpleDLA

In [2]:
from dataset_Generator import DataSetGenerator
from utils import load_config
from torch.utils.data import DataLoader

config = load_config('hyperparameters/cifar10/4cifar10_get_score-FULL.yaml')


In [3]:
train_config = config.get('cifar10',None).train
test_config = config.get('cifar10',None).test
trainset =  DataSetGenerator(train_config)
testset =  DataSetGenerator(test_config)
sampler = None


trainloader =  DataLoader(dataset = trainset,
                                    batch_size = config.batch_size,
                                    # shuffle = True, commented due to sampler (both are mutually exclusive)
                                    num_workers = 12,
                                    pin_memory = True,
                                    drop_last=False, # earlier was true
                                    # sampler = sampler,
                                    shuffle = True

                                    )
    
testloader =  DataLoader(dataset = testset,
                                    batch_size = config.batch_size,
                                    shuffle = True,
                                    num_workers = 12,
                                    pin_memory = True,
                                    drop_last=False # earlier was true
                                    )




class wise count for train data- [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
class wise count for test data- [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]


In [4]:
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch


classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = SimpleDLA()
net = net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in tqdm(enumerate(trainloader)):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()


def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in tqdm(enumerate(testloader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()


    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

    print(acc)


for epoch in range(start_epoch, start_epoch+50):
    train(epoch)
    test(epoch)
    scheduler.step()

==> Building model..

Epoch: 0


196it [00:18, 10.38it/s]
40it [00:01, 24.03it/s]


Saving..
38.68

Epoch: 1


196it [00:17, 11.06it/s]
40it [00:01, 29.14it/s]


Saving..
55.56

Epoch: 2


196it [00:18, 10.57it/s]
40it [00:01, 29.56it/s]


Saving..
71.85

Epoch: 3


196it [00:18, 10.59it/s]
40it [00:01, 30.20it/s]


Saving..
76.0

Epoch: 4


196it [00:18, 10.43it/s]
40it [00:01, 25.68it/s]


Saving..
76.36

Epoch: 5


196it [00:18, 10.61it/s]
40it [00:01, 29.47it/s]


Saving..
79.22

Epoch: 6


196it [00:18, 10.61it/s]
40it [00:01, 27.68it/s]


Saving..
79.93

Epoch: 7


196it [00:18, 10.48it/s]
40it [00:01, 29.98it/s]


Saving..
83.02

Epoch: 8


196it [00:18, 10.61it/s]
40it [00:01, 30.20it/s]

82.94

Epoch: 9



196it [00:18, 10.67it/s]
40it [00:01, 29.39it/s]


Saving..
83.7

Epoch: 10


196it [00:17, 11.00it/s]
40it [00:01, 24.92it/s]

81.32

Epoch: 11



196it [00:17, 10.96it/s]
40it [00:01, 29.41it/s]

83.55

Epoch: 12



196it [00:17, 10.98it/s]
40it [00:01, 31.06it/s]


Saving..
84.16

Epoch: 13


196it [00:17, 11.04it/s]
40it [00:01, 30.15it/s]


Saving..
86.19

Epoch: 14


196it [00:17, 11.00it/s]
40it [00:01, 26.97it/s]

84.69

Epoch: 15



196it [00:17, 11.08it/s]
40it [00:01, 30.64it/s]

86.15

Epoch: 16



196it [00:17, 11.05it/s]
40it [00:01, 30.60it/s]

85.97

Epoch: 17



196it [00:17, 11.02it/s]
40it [00:01, 30.72it/s]

83.89

Epoch: 18



196it [00:17, 11.09it/s]
40it [00:01, 30.36it/s]


Saving..
87.08

Epoch: 19


196it [00:17, 11.06it/s]
40it [00:01, 31.01it/s]

86.51

Epoch: 20



196it [00:17, 11.05it/s]
40it [00:01, 30.00it/s]

86.29

Epoch: 21



196it [00:17, 11.09it/s]
40it [00:01, 30.84it/s]


Saving..
87.98

Epoch: 22


196it [00:17, 10.97it/s]
40it [00:01, 29.74it/s]

87.41

Epoch: 23



196it [00:17, 11.05it/s]
40it [00:01, 31.15it/s]

87.96

Epoch: 24



196it [00:17, 11.07it/s]
40it [00:01, 30.58it/s]


Saving..
88.4

Epoch: 25


196it [00:17, 11.05it/s]
40it [00:01, 30.16it/s]

84.78

Epoch: 26



196it [00:17, 11.09it/s]
40it [00:01, 30.25it/s]

87.35

Epoch: 27



196it [00:17, 11.08it/s]
40it [00:01, 30.74it/s]


Saving..
88.48

Epoch: 28


196it [00:17, 11.01it/s]
40it [00:01, 29.56it/s]

87.36

Epoch: 29



196it [00:17, 11.09it/s]
40it [00:01, 30.70it/s]

88.05

Epoch: 30



196it [00:17, 11.05it/s]
40it [00:01, 30.86it/s]


Saving..
88.99

Epoch: 31


196it [00:17, 11.00it/s]
40it [00:01, 27.73it/s]


Saving..
89.0

Epoch: 32


196it [00:17, 10.99it/s]
40it [00:01, 30.07it/s]

88.74

Epoch: 33



196it [00:17, 11.02it/s]
40it [00:01, 29.31it/s]

88.21

Epoch: 34



196it [00:17, 10.96it/s]
40it [00:01, 30.32it/s]


Saving..
90.1

Epoch: 35


196it [00:17, 11.02it/s]
40it [00:01, 30.53it/s]

88.92

Epoch: 36



196it [00:17, 11.01it/s]
40it [00:01, 30.51it/s]

88.84

Epoch: 37



196it [00:17, 11.01it/s]
40it [00:01, 29.04it/s]

89.02

Epoch: 38



196it [00:17, 11.02it/s]
40it [00:01, 30.75it/s]

89.14

Epoch: 39



196it [00:17, 10.91it/s]
40it [00:01, 28.43it/s]

90.07

Epoch: 40



196it [00:17, 10.94it/s]
40it [00:01, 30.69it/s]

88.16

Epoch: 41



196it [00:17, 10.98it/s]
40it [00:01, 29.43it/s]

89.56

Epoch: 42



196it [00:17, 11.03it/s]
40it [00:01, 30.32it/s]

89.42

Epoch: 43



196it [00:17, 11.01it/s]
40it [00:01, 30.95it/s]

89.98

Epoch: 44



196it [00:17, 11.01it/s]
40it [00:01, 29.43it/s]


Saving..
90.42

Epoch: 45


196it [00:17, 11.02it/s]
40it [00:01, 29.55it/s]

89.34

Epoch: 46



196it [00:17, 10.98it/s]
40it [00:01, 30.37it/s]

89.84

Epoch: 47



196it [00:17, 10.99it/s]
40it [00:01, 29.44it/s]

89.97

Epoch: 48



196it [00:17, 11.03it/s]
40it [00:01, 29.74it/s]

89.43

Epoch: 49



196it [00:17, 11.03it/s]
40it [00:01, 30.07it/s]

90.12





In [4]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

device = 'cuda:1'

# Data loading and transformation
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Model definition (example using a simplified ResNet block)
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, num_blocks=[2,2,2,2], num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(512, num_blocks[3], stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(BasicBlock(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

net = ResNet() # Example ResNet18-like architecture
net = net.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Training loop
for epoch in range(100):  # Adjust number of epochs as needed
    for i, data in tqdm(enumerate(trainloader, 0)):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluation on test set (simplified)
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Epoch {epoch+1}, Accuracy: {accuracy:.2f}%')

print('Finished Training')

Files already downloaded and verified
Files already downloaded and verified


391it [00:11, 33.49it/s]
100%|██████████| 100/100 [00:01<00:00, 61.96it/s]

Epoch 1, Accuracy: 58.90%



391it [00:11, 33.28it/s]
100%|██████████| 100/100 [00:01<00:00, 66.46it/s]

Epoch 2, Accuracy: 70.10%



391it [00:11, 33.44it/s]
100%|██████████| 100/100 [00:01<00:00, 61.98it/s]

Epoch 3, Accuracy: 75.12%



391it [00:11, 33.01it/s]
100%|██████████| 100/100 [00:01<00:00, 74.12it/s]

Epoch 4, Accuracy: 78.87%



391it [00:11, 32.77it/s]
100%|██████████| 100/100 [00:01<00:00, 68.36it/s]

Epoch 5, Accuracy: 80.35%



391it [00:11, 32.77it/s]
100%|██████████| 100/100 [00:01<00:00, 69.27it/s]

Epoch 6, Accuracy: 81.17%



391it [00:11, 33.53it/s]
100%|██████████| 100/100 [00:01<00:00, 83.95it/s]

Epoch 7, Accuracy: 83.99%



391it [00:11, 33.16it/s]
100%|██████████| 100/100 [00:01<00:00, 73.15it/s]

Epoch 8, Accuracy: 84.71%



391it [00:13, 29.49it/s]
100%|██████████| 100/100 [00:01<00:00, 63.38it/s]

Epoch 9, Accuracy: 85.57%



391it [00:11, 33.47it/s]
100%|██████████| 100/100 [00:01<00:00, 67.29it/s]


Epoch 10, Accuracy: 86.29%


391it [00:11, 33.24it/s]
100%|██████████| 100/100 [00:01<00:00, 62.14it/s]


Epoch 11, Accuracy: 86.64%


391it [00:11, 33.43it/s]
100%|██████████| 100/100 [00:01<00:00, 66.45it/s]

Epoch 12, Accuracy: 86.80%



391it [00:11, 33.25it/s]
100%|██████████| 100/100 [00:01<00:00, 70.45it/s]

Epoch 13, Accuracy: 86.20%



391it [00:11, 33.69it/s]
100%|██████████| 100/100 [00:02<00:00, 49.23it/s]

Epoch 14, Accuracy: 87.98%



391it [00:11, 32.68it/s]
100%|██████████| 100/100 [00:01<00:00, 61.82it/s]

Epoch 15, Accuracy: 88.79%



391it [00:11, 33.63it/s]
100%|██████████| 100/100 [00:01<00:00, 81.58it/s]

Epoch 16, Accuracy: 89.04%



391it [00:11, 32.85it/s]
100%|██████████| 100/100 [00:01<00:00, 51.47it/s]

Epoch 17, Accuracy: 87.91%



391it [00:11, 33.44it/s]
100%|██████████| 100/100 [00:01<00:00, 62.82it/s]

Epoch 18, Accuracy: 87.99%



391it [00:11, 33.10it/s]
100%|██████████| 100/100 [00:02<00:00, 41.01it/s]

Epoch 19, Accuracy: 88.97%



391it [00:12, 31.17it/s]
100%|██████████| 100/100 [00:02<00:00, 48.80it/s]

Epoch 20, Accuracy: 88.60%



391it [00:12, 32.37it/s]
100%|██████████| 100/100 [00:01<00:00, 67.25it/s]

Epoch 21, Accuracy: 89.19%



391it [00:11, 32.73it/s]
100%|██████████| 100/100 [00:01<00:00, 65.60it/s]

Epoch 22, Accuracy: 88.94%



391it [00:11, 33.29it/s]
100%|██████████| 100/100 [00:01<00:00, 60.81it/s]

Epoch 23, Accuracy: 88.92%



391it [00:12, 32.22it/s]
100%|██████████| 100/100 [00:01<00:00, 68.77it/s]

Epoch 24, Accuracy: 89.98%



391it [00:11, 33.24it/s]
100%|██████████| 100/100 [00:01<00:00, 78.55it/s]

Epoch 25, Accuracy: 89.64%



391it [00:11, 33.21it/s]
100%|██████████| 100/100 [00:01<00:00, 77.22it/s]

Epoch 26, Accuracy: 89.02%



391it [00:11, 33.24it/s]
100%|██████████| 100/100 [00:01<00:00, 72.08it/s]

Epoch 27, Accuracy: 89.40%



391it [00:11, 32.85it/s]
100%|██████████| 100/100 [00:01<00:00, 76.29it/s]

Epoch 28, Accuracy: 89.31%



391it [00:11, 32.95it/s]
100%|██████████| 100/100 [00:01<00:00, 65.49it/s]

Epoch 29, Accuracy: 90.08%



391it [00:12, 32.49it/s]
100%|██████████| 100/100 [00:01<00:00, 63.07it/s]

Epoch 30, Accuracy: 90.32%



391it [00:12, 32.25it/s]
100%|██████████| 100/100 [00:01<00:00, 66.10it/s]

Epoch 31, Accuracy: 90.18%



391it [00:12, 31.73it/s]
100%|██████████| 100/100 [00:01<00:00, 71.95it/s]

Epoch 32, Accuracy: 90.24%



391it [00:12, 32.17it/s]
100%|██████████| 100/100 [00:01<00:00, 72.17it/s]

Epoch 33, Accuracy: 90.09%



391it [00:12, 32.18it/s]
100%|██████████| 100/100 [00:01<00:00, 67.32it/s]

Epoch 34, Accuracy: 89.40%



391it [00:12, 32.34it/s]
100%|██████████| 100/100 [00:01<00:00, 63.98it/s]

Epoch 35, Accuracy: 90.36%



391it [00:11, 33.30it/s]
100%|██████████| 100/100 [00:01<00:00, 67.62it/s]


Epoch 36, Accuracy: 90.28%


391it [00:11, 33.23it/s]
100%|██████████| 100/100 [00:02<00:00, 48.44it/s]

Epoch 37, Accuracy: 89.69%



391it [00:12, 32.12it/s]
100%|██████████| 100/100 [00:01<00:00, 65.41it/s]

Epoch 38, Accuracy: 90.14%



391it [00:11, 33.02it/s]
100%|██████████| 100/100 [00:01<00:00, 77.99it/s]

Epoch 39, Accuracy: 90.32%



391it [00:19, 19.92it/s]
100%|██████████| 100/100 [00:02<00:00, 43.36it/s]

Epoch 40, Accuracy: 90.46%



391it [00:14, 26.43it/s]
100%|██████████| 100/100 [00:01<00:00, 80.01it/s]

Epoch 41, Accuracy: 89.64%



391it [00:12, 32.46it/s]
100%|██████████| 100/100 [00:01<00:00, 77.58it/s]

Epoch 42, Accuracy: 90.38%



391it [00:12, 32.49it/s]
100%|██████████| 100/100 [00:01<00:00, 71.99it/s]

Epoch 43, Accuracy: 90.26%



391it [00:12, 31.83it/s]
100%|██████████| 100/100 [00:01<00:00, 50.52it/s]

Epoch 44, Accuracy: 90.25%



391it [00:12, 32.25it/s]
100%|██████████| 100/100 [00:01<00:00, 64.88it/s]

Epoch 45, Accuracy: 90.30%



391it [00:12, 32.20it/s]
100%|██████████| 100/100 [00:01<00:00, 64.92it/s]

Epoch 46, Accuracy: 90.27%



391it [00:12, 32.21it/s]
100%|██████████| 100/100 [00:01<00:00, 57.78it/s]

Epoch 47, Accuracy: 90.66%



391it [00:12, 31.49it/s]
100%|██████████| 100/100 [00:01<00:00, 59.56it/s]

Epoch 48, Accuracy: 90.80%



391it [00:12, 31.65it/s]
100%|██████████| 100/100 [00:01<00:00, 74.84it/s]

Epoch 49, Accuracy: 90.67%



391it [00:12, 31.69it/s]
100%|██████████| 100/100 [00:01<00:00, 75.19it/s]

Epoch 50, Accuracy: 90.13%



391it [00:12, 31.95it/s]
100%|██████████| 100/100 [00:01<00:00, 77.64it/s]

Epoch 51, Accuracy: 90.33%



391it [00:17, 22.25it/s]
100%|██████████| 100/100 [00:02<00:00, 40.61it/s]

Epoch 52, Accuracy: 90.71%



391it [00:16, 23.66it/s]
100%|██████████| 100/100 [00:02<00:00, 48.98it/s]

Epoch 53, Accuracy: 90.95%



391it [00:12, 30.77it/s]
100%|██████████| 100/100 [00:01<00:00, 74.71it/s]

Epoch 54, Accuracy: 89.87%



391it [00:19, 19.76it/s]
100%|██████████| 100/100 [00:02<00:00, 43.29it/s]

Epoch 55, Accuracy: 90.94%



391it [00:15, 25.44it/s]
100%|██████████| 100/100 [00:01<00:00, 70.85it/s]

Epoch 56, Accuracy: 90.86%



391it [00:13, 28.48it/s]
100%|██████████| 100/100 [00:01<00:00, 53.84it/s]

Epoch 57, Accuracy: 91.23%



391it [00:13, 28.34it/s]
100%|██████████| 100/100 [00:01<00:00, 61.19it/s]

Epoch 58, Accuracy: 91.62%



391it [00:13, 28.86it/s]
100%|██████████| 100/100 [00:01<00:00, 64.36it/s]

Epoch 59, Accuracy: 90.00%



391it [00:12, 31.91it/s]
100%|██████████| 100/100 [00:01<00:00, 67.52it/s]

Epoch 60, Accuracy: 90.52%



391it [00:12, 31.52it/s]
100%|██████████| 100/100 [00:02<00:00, 43.89it/s]

Epoch 61, Accuracy: 90.40%



391it [00:23, 16.79it/s]
100%|██████████| 100/100 [00:02<00:00, 44.01it/s]

Epoch 62, Accuracy: 91.12%



391it [00:22, 17.13it/s]
100%|██████████| 100/100 [00:02<00:00, 45.12it/s]

Epoch 63, Accuracy: 90.47%



391it [00:23, 16.70it/s]
100%|██████████| 100/100 [00:02<00:00, 41.26it/s]

Epoch 64, Accuracy: 90.69%



391it [00:23, 16.62it/s]
100%|██████████| 100/100 [00:02<00:00, 43.52it/s]

Epoch 65, Accuracy: 91.21%



391it [00:23, 16.85it/s]
100%|██████████| 100/100 [00:02<00:00, 42.65it/s]

Epoch 66, Accuracy: 91.05%



391it [00:23, 16.71it/s]
100%|██████████| 100/100 [00:02<00:00, 41.85it/s]

Epoch 67, Accuracy: 90.88%



391it [00:23, 16.98it/s]
100%|██████████| 100/100 [00:02<00:00, 42.54it/s]

Epoch 68, Accuracy: 91.37%



391it [00:22, 17.32it/s]
100%|██████████| 100/100 [00:02<00:00, 46.73it/s]

Epoch 69, Accuracy: 90.31%



391it [00:22, 17.74it/s]
100%|██████████| 100/100 [00:02<00:00, 47.71it/s]

Epoch 70, Accuracy: 90.87%



391it [00:22, 17.71it/s]
100%|██████████| 100/100 [00:02<00:00, 46.74it/s]

Epoch 71, Accuracy: 90.63%



391it [00:21, 17.87it/s]
100%|██████████| 100/100 [00:02<00:00, 46.72it/s]

Epoch 72, Accuracy: 90.85%



391it [00:22, 17.71it/s]
100%|██████████| 100/100 [00:02<00:00, 46.68it/s]

Epoch 73, Accuracy: 91.12%



391it [00:22, 17.73it/s]
100%|██████████| 100/100 [00:02<00:00, 46.63it/s]

Epoch 74, Accuracy: 91.49%



391it [00:22, 17.02it/s]
100%|██████████| 100/100 [00:02<00:00, 45.79it/s]

Epoch 75, Accuracy: 91.44%



391it [00:22, 17.15it/s]
100%|██████████| 100/100 [00:02<00:00, 46.48it/s]

Epoch 76, Accuracy: 90.80%



391it [00:21, 18.10it/s]
100%|██████████| 100/100 [00:02<00:00, 49.02it/s]

Epoch 77, Accuracy: 91.34%



391it [00:21, 18.02it/s]
100%|██████████| 100/100 [00:02<00:00, 48.23it/s]

Epoch 78, Accuracy: 91.21%



391it [00:22, 17.68it/s]
100%|██████████| 100/100 [00:02<00:00, 47.20it/s]

Epoch 79, Accuracy: 91.51%



391it [00:21, 18.11it/s]
100%|██████████| 100/100 [00:02<00:00, 47.70it/s]

Epoch 80, Accuracy: 91.46%



391it [00:21, 18.03it/s]
100%|██████████| 100/100 [00:02<00:00, 48.89it/s]

Epoch 81, Accuracy: 90.89%



391it [00:22, 17.63it/s]
100%|██████████| 100/100 [00:02<00:00, 44.25it/s]

Epoch 82, Accuracy: 91.16%



391it [00:21, 17.90it/s]
100%|██████████| 100/100 [00:02<00:00, 48.58it/s]

Epoch 83, Accuracy: 91.62%



391it [00:22, 17.56it/s]
100%|██████████| 100/100 [00:01<00:00, 56.78it/s]

Epoch 84, Accuracy: 91.08%



391it [00:22, 17.43it/s]
100%|██████████| 100/100 [00:01<00:00, 57.41it/s]


Epoch 85, Accuracy: 91.55%


391it [00:22, 17.31it/s]
100%|██████████| 100/100 [00:01<00:00, 54.72it/s]

Epoch 86, Accuracy: 90.94%



391it [00:22, 17.30it/s]
100%|██████████| 100/100 [00:01<00:00, 58.17it/s]

Epoch 87, Accuracy: 91.33%



391it [00:22, 17.33it/s]
100%|██████████| 100/100 [00:01<00:00, 58.23it/s]

Epoch 88, Accuracy: 90.97%



391it [00:22, 17.17it/s]
100%|██████████| 100/100 [00:01<00:00, 61.60it/s]

Epoch 89, Accuracy: 91.34%



391it [00:22, 17.26it/s]
100%|██████████| 100/100 [00:02<00:00, 48.96it/s]

Epoch 90, Accuracy: 91.79%



391it [00:22, 17.74it/s]
100%|██████████| 100/100 [00:02<00:00, 47.56it/s]

Epoch 91, Accuracy: 91.70%



391it [00:22, 17.70it/s]
100%|██████████| 100/100 [00:02<00:00, 48.35it/s]

Epoch 92, Accuracy: 91.31%



391it [00:22, 17.73it/s]
100%|██████████| 100/100 [00:02<00:00, 46.44it/s]

Epoch 93, Accuracy: 91.51%



391it [00:21, 17.77it/s]
100%|██████████| 100/100 [00:02<00:00, 46.63it/s]

Epoch 94, Accuracy: 91.07%



391it [00:21, 17.85it/s]
100%|██████████| 100/100 [00:02<00:00, 46.95it/s]

Epoch 95, Accuracy: 91.57%



391it [00:21, 17.84it/s]
100%|██████████| 100/100 [00:02<00:00, 46.14it/s]

Epoch 96, Accuracy: 91.48%



391it [00:21, 18.17it/s]
100%|██████████| 100/100 [00:02<00:00, 47.30it/s]

Epoch 97, Accuracy: 91.41%



391it [00:22, 17.74it/s]
100%|██████████| 100/100 [00:02<00:00, 47.82it/s]

Epoch 98, Accuracy: 91.13%



391it [00:21, 17.94it/s]
100%|██████████| 100/100 [00:02<00:00, 48.21it/s]

Epoch 99, Accuracy: 91.67%



391it [00:21, 18.19it/s]
100%|██████████| 100/100 [00:02<00:00, 48.17it/s]

Epoch 100, Accuracy: 91.18%
Finished Training





In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from model_dla import DLA

device = 'cuda:1'

# Data loading and transformation
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# net = ResNet([2, 2, 2, 2]) # Example ResNet18-like architecture
net = DLA()
net = net.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Training loop
for epoch in range(20):  # Adjust number of epochs as needed
    for i, data in tqdm(enumerate(trainloader, 0)):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluation on test set (simplified)
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Epoch {epoch+1}, Accuracy: {accuracy:.2f}%')

print('Finished Training')

Files already downloaded and verified
Files already downloaded and verified


391it [00:21, 18.32it/s]
100%|██████████| 100/100 [00:01<00:00, 54.80it/s]


Epoch 1, Accuracy: 51.93%


391it [00:21, 18.47it/s]
100%|██████████| 100/100 [00:01<00:00, 56.26it/s]

Epoch 2, Accuracy: 61.71%



391it [00:21, 18.52it/s]
100%|██████████| 100/100 [00:01<00:00, 51.68it/s]

Epoch 3, Accuracy: 71.02%



391it [00:21, 18.52it/s]
100%|██████████| 100/100 [00:01<00:00, 53.20it/s]


Epoch 4, Accuracy: 74.32%


391it [00:21, 18.53it/s]
100%|██████████| 100/100 [00:01<00:00, 54.80it/s]

Epoch 5, Accuracy: 77.95%



391it [00:21, 18.46it/s]
100%|██████████| 100/100 [00:01<00:00, 55.60it/s]

Epoch 6, Accuracy: 79.41%



391it [00:21, 18.53it/s]
100%|██████████| 100/100 [00:01<00:00, 55.48it/s]

Epoch 7, Accuracy: 81.85%



391it [00:21, 18.11it/s]
100%|██████████| 100/100 [00:01<00:00, 50.07it/s]

Epoch 8, Accuracy: 82.30%



391it [00:21, 18.31it/s]
100%|██████████| 100/100 [00:01<00:00, 54.48it/s]

Epoch 9, Accuracy: 83.23%



391it [00:21, 18.58it/s]
100%|██████████| 100/100 [00:01<00:00, 56.63it/s]

Epoch 10, Accuracy: 82.56%



391it [00:21, 18.59it/s]
100%|██████████| 100/100 [00:01<00:00, 56.21it/s]

Epoch 11, Accuracy: 84.62%



391it [00:21, 18.59it/s]
100%|██████████| 100/100 [00:01<00:00, 56.29it/s]

Epoch 12, Accuracy: 85.28%



391it [00:21, 18.59it/s]
100%|██████████| 100/100 [00:01<00:00, 56.49it/s]

Epoch 13, Accuracy: 86.05%



391it [00:21, 18.56it/s]
100%|██████████| 100/100 [00:01<00:00, 56.76it/s]

Epoch 14, Accuracy: 85.69%



391it [00:21, 18.62it/s]
100%|██████████| 100/100 [00:01<00:00, 56.35it/s]

Epoch 15, Accuracy: 86.47%



391it [00:21, 18.22it/s]
100%|██████████| 100/100 [00:02<00:00, 39.95it/s]

Epoch 16, Accuracy: 86.92%



391it [00:21, 18.05it/s]
100%|██████████| 100/100 [00:02<00:00, 38.81it/s]

Epoch 17, Accuracy: 86.89%



391it [00:22, 17.56it/s]
100%|██████████| 100/100 [00:03<00:00, 28.59it/s]

Epoch 18, Accuracy: 86.41%



391it [00:22, 17.42it/s]
100%|██████████| 100/100 [00:02<00:00, 43.38it/s]

Epoch 19, Accuracy: 87.31%



391it [00:22, 17.76it/s]
100%|██████████| 100/100 [00:01<00:00, 55.22it/s]

Epoch 20, Accuracy: 88.42%
Finished Training



