In [75]:
import torch
from torch import nn
from torch.utils.data import Dataset, random_split, DataLoader

import torchvision
from torchvision import models, transforms
from torchvision.io import read_image, ImageReadMode

from tqdm.auto import tqdm
import os
import pandas as pd

In [76]:
# Constants

TRAIN_SPLIT = 0.8
BATCH_SIZE = 32
MODEL_INPUT_H = 224
MODEL_INPUT_W = 224

device = "cuda" if torch.cuda.is_available() else "cpu"

In [77]:
class HTMLElementDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        # Create a dictionary to map tag to number
        self.tags = dict(map(lambda x: (x[1], x[0]), enumerate(pd.unique(self.img_labels['tag']))))

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 2])
        image = read_image(img_path, mode=ImageReadMode.RGB).float() / 255
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, self.tags[label]

In [78]:
class HTMLElementClassifier(nn.Module):
    def __init__(self, num_classes):
        super(HTMLElementClassifier, self).__init__()
        self.resnet = models.resnet18(num_classes=num_classes)

    def forward(self, x):
        return self.resnet(x)

In [79]:
transform = transforms.Compose([
    transforms.Resize((MODEL_INPUT_H, MODEL_INPUT_W)),
])

dataset = HTMLElementDataset(annotations_file='data/annotations.csv', img_dir='data/cropped-by-semantic-tag/', transform=transform)

num_samples = len(dataset)
num_train = int(TRAIN_SPLIT * num_samples)
num_test = num_samples - num_train

train_data, test_data = random_split(dataset, [num_train, num_test])

train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

In [80]:
model = HTMLElementClassifier(num_classes=len(dataset.tags))
model.to(device)
print(model)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

HTMLElementClassifier(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tr

In [81]:
num_epochs = 10

for epoch in range(num_epochs):
    print('Epoch:', epoch)
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

torch.save(model.state_dict(), 'html_classifier_model.pth')

Epoch: 0


  0%|          | 0/333 [00:00<?, ?it/s]



Epoch [1/10], Loss: 1.7031
Epoch: 1


  0%|          | 0/333 [00:00<?, ?it/s]

Epoch [2/10], Loss: 1.8602
Epoch: 2


  0%|          | 0/333 [00:00<?, ?it/s]

Epoch [3/10], Loss: 1.4339
Epoch: 3


  0%|          | 0/333 [00:00<?, ?it/s]

Epoch [4/10], Loss: 1.4687
Epoch: 4


  0%|          | 0/333 [00:00<?, ?it/s]

Epoch [5/10], Loss: 1.7144
Epoch: 5


  0%|          | 0/333 [00:00<?, ?it/s]

Epoch [6/10], Loss: 1.4562
Epoch: 6


  0%|          | 0/333 [00:00<?, ?it/s]

Epoch [7/10], Loss: 1.1432
Epoch: 7


  0%|          | 0/333 [00:00<?, ?it/s]

Epoch [8/10], Loss: 0.8755
Epoch: 8


  0%|          | 0/333 [00:00<?, ?it/s]

Epoch [9/10], Loss: 0.7459
Epoch: 9


  0%|          | 0/333 [00:00<?, ?it/s]

Epoch [10/10], Loss: 0.6300


In [82]:
model.eval()
with torch.inference_mode(): 
    test_loss = 0.
    test_acc = 0.
    for images, labels in tqdm(test_loader):
        # Send data to GPU
        images, labels = images.to(device), labels.to(device)
        
        test_pred = model(images)
        
        # Calculate loss and accuracy
        test_loss += criterion(test_pred, labels)
        test_acc += float(torch.sum(labels == test_pred.argmax(dim=1)))

    # Adjust metrics and print out
    test_loss /= float(len(test_loader))
    test_acc /= float(len(test_loader))
    print(f"Test loss: {test_loss:.5f} | Test accuracy: {test_acc:.2f}%\n")

  0%|          | 0/84 [00:00<?, ?it/s]

Test loss: 1.85141 | Test accuracy: 13.05%

