In [14]:
! pip install pillow torch torchvision torchmetrics

1612.58s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


Defaulting to user installation because normal site-packages is not writeable
Collecting torchmetrics
  Downloading torchmetrics-1.3.2-py3-none-any.whl.metadata (19 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.2-py3-none-any.whl.metadata (4.7 kB)
Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading lightning_utilities-0.11.2-py3-none-any.whl (26 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.2 torchmetrics-1.3.2

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49m/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade 

In [5]:
import torch 
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

### Multi Input Models

In [6]:
class OmniglotDataset(Dataset):

    def __init__(self, transform, samples): 
        self.transform = transform
        self.samples = samples 

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, alphabet, label = self.samples[idx]
        img = Image.open(img_path).convert('L') # returns the gray scale image
        img = self.transform(img)
        return img, alphabet, label

    

In [7]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Define sub-networks as sequential models
        self.image_layer = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ELU(),
            nn.Flatten(),
            nn.Linear(16*32*32, 128)
        )
        self.alphabet_layer = nn.Sequential(
            nn.Linear(30, 8),
            nn.ELU(), 
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 + 8, 964), 
        )
        
    def forward(self, x_image, x_alphabet):
		# Pass the x_image and x_alphabet through appropriate layers
        x_image = self.image_layer(x_image)
        x_alphabet = self.alphabet_layer(x_alphabet)
        # Concatenate x_image and x_alphabet
        x = torch.cat((x_image, x_alphabet), dim=1)
        return self.classifier(x)

In [11]:
dataset_train = OmniglotDataset(
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
    ]),
    samples='./omniglot_train/',
)

dataloader_train = DataLoader(
    dataset_train, shuffle=True, batch_size=3,
)

In [12]:
net = Net()
criterion = nn.CrossEntropyLoss() #As the classifier is multi class
optimizer = optim.SGD(net.parameters(), lr=0.01)

In [13]:
for epoch in range(10):
    for image, alphabet, label in dataloader_train:
        output = Net(image, alphabet)
        loss = criterion(output, label)
        optim.zero_grad()
        loss.backward()
        optim.step()

ValueError: not enough values to unpack (expected 3, got 1)

### Multi Ouput Model

In [None]:
# Print the sample at index 100
print(samples[100])

# Create dataset_train
dataset_train = OmniglotDataset(
    transform=transforms.Compose([
        transforms.ToTensor(),
      	transforms.Resize((64, 64)),
    ]),
    samples=samples,
)

# Create dataloader_train
dataloader_train = DataLoader(
    dataset_train, batch_size=32, shuffle=True,
)

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.image_layer = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ELU(),
            nn.Flatten(),
            nn.Linear(16*32*32, 128)
        )
        # Define the two classifier layers
        self.classifier_alpha = nn.Linear(128, 30)
        self.classifier_char = nn.Linear(128, 964)
        
    def forward(self, x):
        x_image = self.image_layer(x)
        # Pass x_image through the classifiers and return both results
        output_alpha = self.classifier_alpha(x_image)
        output_char = self.classifier_char(x_image)
        return output_alpha, output_char

In [17]:
char_weight = 0.5

In [18]:
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.05)

for epoch in range(1):
    for images, labels_alpha, labels_char in dataloader_train:
        optimizer.zero_grad()
        outputs_alpha, outputs_char = net(images)
        # Compute alphabet classification loss
        loss_alpha = criterion(outputs_alpha, labels_alpha)
        # Compute character classification loss
        loss_char = criterion(outputs_char, labels_char)
        # Compute total loss
        loss = ((1 - char_weight) * loss_alpha) + (char_weight * loss_char)
        loss.backward()
        optimizer.step()

ValueError: not enough values to unpack (expected 3, got 1)

#### Evaluation

In [15]:
from torchmetrics import Accuracy

In [16]:
acc_alpha = Accuracy(task ="multiclass", num_classes=30)
acc_char = Accuracy(task='multiclass', num_classes=964)

In [None]:
def evaluate_model(model):
    # Define accuracy metrics
    acc_alpha = Accuracy(task="multiclass", num_classes=30)
    acc_char = Accuracy(task="multiclass", num_classes=964)

    model.eval()
    with torch.no_grad():
        for images, labels_alpha, labels_char in dataloader_test:
            # Obtain model outputs
            outputs_alpha, outputs_char = model(images)
            _, pred_alpha = torch.max(outputs_alpha, 1)
            _, pred_char = torch.max(outputs_char, 1)
			# Update both accuracy metrics
            acc_alpha(pred_alpha, labels_alpha)
            acc_char(pred_char, labels_char)
    
    print(f"Alphabet: {acc_alpha.compute()}")
    print(f"Character: {acc_char.compute()}")