In [76]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [77]:
def load_data(id):

    transform = transforms.Compose([transforms.Grayscale(), 
                                transforms.Resize((64, 64)), 
                                transforms.ToTensor()])
    
    if(id == 1):
        dataset = torchvision.datasets.ImageFolder(root='./client1', transform=transform)
    elif(id == 2):
        dataset = torchvision.datasets.ImageFolder(root='./client2', transform=transform)
    elif(id == 3):
        dataset = torchvision.datasets.ImageFolder(root='./client3', transform=transform)
    elif(id == 4):
        dataset = torchvision.datasets.ImageFolder(root='./client4', transform=transform)
    elif(id == 5):
        dataset = torchvision.datasets.ImageFolder(root='./client5', transform=transform)
    elif(id == 6):
        dataset = torchvision.datasets.ImageFolder(root='./client6', transform=transform)
    elif(id == 7):
        dataset = torchvision.datasets.ImageFolder(root='./client7', transform=transform)
    elif(id == 8):
        dataset = torchvision.datasets.ImageFolder(root='./client8', transform=transform)
    else:
        raise ValueError("Invalid client ID.")

    train_ratio = 0.8
    size = len(dataset)
    train_size = int(train_ratio * size)
    test_size = size - train_size

    generator = torch.Generator().manual_seed(42)
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size], generator=generator)

    train_dataset, test_dataset = mappingLable(id, train_dataset, test_dataset)

    batch_size = 32
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

class MappedDataset(torch.utils.data.Dataset):
    def __init__(self, subset, label_map):
        self.subset = subset
        self.label_map = label_map

    def __len__(self):
        return len(self.subset)
    
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        # Nếu nhãn có trong map thì đổi, ngược lại giữ nguyên
        y = self.label_map.get(y, y)
        return x, y

def mappingLable(id, train, test):
    #Mapped Lables
    lable_map1 = {0: 0, 1: 1, 2: 2}
    lable_map2 = {0: 3, 1: 4, 2: 5}
    lable_map3 = {0: 6, 1: 7, 2: 8}
    lable_map4 = {0: 9, 1: 10, 2: 11}
    lable_map5 = {0: 12, 1: 13, 2: 14}
    lable_map6 = {0: 15, 1: 16, 2: 17}
    lable_map7 = {0: 18, 1: 19, 2: 20}
    lable_map8 = {0: 21, 1: 22, 2: 23, 3: 24}

    if id == 1:
        trainset = MappedDataset(train, label_map=lable_map1)
        testset = MappedDataset(test, label_map=lable_map1)
    elif id == 2:
        trainset = MappedDataset(train, label_map=lable_map2)
        testset = MappedDataset(test, label_map=lable_map2)
    elif id == 3:
        trainset = MappedDataset(train, label_map=lable_map3)
        testset = MappedDataset(test, label_map=lable_map3)
    elif id == 4:
        trainset = MappedDataset(train, label_map=lable_map4)
        testset = MappedDataset(test, label_map=lable_map4)
    elif id == 5:
        trainset = MappedDataset(train, label_map=lable_map5)
        testset = MappedDataset(test, label_map=lable_map5)
    elif id == 6:
        trainset = MappedDataset(train, label_map=lable_map6)
        testset = MappedDataset(test, label_map=lable_map6)
    elif id == 7:
        trainset = MappedDataset(train, label_map=lable_map7)
        testset = MappedDataset(test, label_map=lable_map7)
    elif id == 8:
        trainset = MappedDataset(train, label_map=lable_map8)
        testset = MappedDataset(test, label_map=lable_map8)

    return trainset, testset

In [90]:
trainloader1, testloader1 = load_data(8)

In [91]:
print(testloader1)

<torch.utils.data.dataloader.DataLoader object at 0x0000020CD0737090>


In [92]:
for images, labels in testloader1:
    print(images.shape, labels)
    break

torch.Size([32, 1, 64, 64]) tensor([22, 22, 22, 24, 23, 24, 24, 22, 21, 22, 24, 24, 24, 24, 23, 23, 24, 24,
        24, 23, 22, 22, 24, 22, 24, 21, 21, 22, 21, 24, 24, 24])


In [81]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=25, num_filters_conv1=32, num_filters_conv2=64, fc1_neurons=128, image_size=64):
        super().__init__()
        self.conv_layer1 = nn.Sequential(
            nn.Conv2d(1, num_filters_conv1, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        image_size = image_size // 2
        self.conv_layer2 = nn.Sequential(
            nn.Conv2d(num_filters_conv1, num_filters_conv2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        image_size = image_size // 2
        
        self.fc_layers = nn.Sequential(
            nn.Linear(image_size * image_size * num_filters_conv2, fc1_neurons),
            nn.ReLU(),
            nn.Linear(fc1_neurons, num_classes)
        )

    def forward(self, x):
        x = self.conv_layer1(x)
        x = self.conv_layer2(x)
        x = x.view(x.size(0), -1) # Flatten the output for the linear layers
        x = self.fc_layers(x)
        return x

In [82]:
model = SimpleCNN()

In [83]:
model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
lossFn = torch.nn.CrossEntropyLoss().to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model.train()

SimpleCNN(
  (conv_layer1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv_layer2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layers): Sequential(
    (0): Linear(in_features=16384, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=25, bias=True)
  )
)

In [84]:
image = None

for images, labels in testloader1:
    image = images
    break

In [85]:
print(image.shape)

torch.Size([32, 1, 64, 64])


In [86]:
image = image.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

In [87]:
output = model(image)

In [88]:
output.shape

torch.Size([32, 25])