In [None]:
import os

import matplotlib.image
import torch

from loss import CrossEntropyLoss
from ImageDataset import ImageSet, DeviceDataLoader
from model import DeepLabv3
from trainers import fit
from utils import remove_small_artifacts


# Prefer the GPU, but fall back to the CPU so this cell also runs on machines
# without CUDA (the training cell below already guards on cuda availability).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoints
# that were produced by this project.
# map_location places the stored tensors on the device selected above.
best_model = torch.load("saved_models/DeepLabv3_plus_2.pt", map_location=device)
best_model.eval()

# Remove the background of every image to build a new dataset that is
# possibly easier to train on.
image_dir = "data/images"
label_dir = "data/labels/person"

MyPeopleSet = ImageSet(image_dir, label_dir)

# imsave does not create missing directories, so make sure the target exists.
os.makedirs("data/no_bg", exist_ok=True)

# Pure inference: disabling autograd avoids building a graph for every image.
with torch.no_grad():
    for index, image_tensor, _label in MyPeopleSet:
        # NOTE(review): this is the object stored in MyPeopleSet.images and it
        # is edited in place below — confirm that mutating the dataset's
        # cached image is intended.
        image = MyPeopleSet.images[index]

        # Add a leading batch dimension and move the input to the model's device.
        batch = image_tensor.unsqueeze(0).to(device)
        pred_mask = best_model(batch).argmax(dim=1).cpu().numpy()[0]
        pred_mask = remove_small_artifacts(pred_mask)

        # Black out every pixel whose predicted class index is 0
        # (presumably the background class — confirm against the label encoding).
        image[pred_mask == 0] = [0, 0, 0]

        matplotlib.image.imsave(f"data/no_bg/mask_{(index+1):04d}.png", image)

In [None]:
# Tunable settings for this run, gathered in one place.
SPLIT_SEED = 42
SPLIT_FRACTIONS = (0.5, 0.1, 0.4)  # train / val / test
BATCH_SIZE = 8
NUM_CLASSES = 7  # argument passed to DeepLabv3; presumably the number of classes — confirm
NUM_EPOCHS = 30  # first argument of fit; presumably the epoch count — confirm
LEARNING_RATE = 0.00005

# Dataset built from the background-free images written to data/no_bg.
MyPeopleNoBgSet = ImageSet("data/no_bg", "data/labels/clothes", label_type="multi")

# Seed the global RNG right before splitting so the split is reproducible.
torch.manual_seed(SPLIT_SEED)
train, val, test = torch.utils.data.random_split(MyPeopleNoBgSet, SPLIT_FRACTIONS)
train_loader = DeviceDataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DeviceDataLoader(val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DeviceDataLoader(test, batch_size=1, shuffle=False)

loss = CrossEntropyLoss()
model = DeepLabv3(NUM_CLASSES)
if torch.cuda.is_available():
    model.cuda()

fit(
    NUM_EPOCHS,
    torch.optim.Adam,
    model,
    loss,
    LEARNING_RATE,
    train_loader,
    val_loader,
    torch.optim.lr_scheduler.ReduceLROnPlateau,
)

In [None]:
# Checkpoint to evaluate. NOTE: torch.load unpickles arbitrary Python objects,
# so only point this at files produced by this project.
checkpoint_path = "saved_models/DeepLabv3_7.pt"
best_model = torch.load(checkpoint_path)
# Switch to evaluation mode before measuring accuracy.
best_model.eval()

In [None]:
def calculate_accuracy(model, test_loader):
    """Compute the pixel-wise accuracy of ``model`` over ``test_loader``.

    Args:
        model: Callable mapping a batch of images to per-class scores. The
            predicted class is the argmax over dimension 1 of its output
            (presumably shaped (batch, classes, H, W) — confirm). If it is a
            ``torch.nn.Module`` in training mode, it is switched to eval mode
            for the duration of the call and restored afterwards.
        test_loader: Iterable yielding ``(index, images, labels)`` triples.
            ``labels`` is compared element-wise with the argmax prediction,
            so it is assumed to have the prediction's shape — confirm.

    Returns:
        The fraction of label elements that equal the prediction, as a float.
        Returns ``0.0`` when the loader yields no label elements, instead of
        raising ``ZeroDivisionError``.
    """
    # Only nn.Module instances carry a train/eval mode; plain callables are
    # used as they are.
    was_training = isinstance(model, torch.nn.Module) and model.training
    if was_training:
        # Dropout / batch-norm must run in inference mode for a meaningful score.
        model.eval()

    correct_pixels = 0
    total_pixels = 0
    try:
        # Enter no_grad once for the whole evaluation rather than per batch.
        with torch.no_grad():
            for _, images, labels in test_loader:
                predicted_masks = torch.argmax(model(images), dim=1)
                correct_pixels += torch.sum(predicted_masks == labels).item()
                total_pixels += torch.numel(labels)
    finally:
        # Leave the model in the mode the caller had it in.
        if was_training:
            model.train()

    if total_pixels == 0:
        return 0.0
    return correct_pixels / total_pixels


# Score the reloaded checkpoint on the held-out test split and report it.
accuracy = calculate_accuracy(model=best_model, test_loader=test_loader)
print(f"Test Accuracy: {accuracy}")

In [None]:
# It seems that removing the background and keeping only the person in the image gives results similar to those obtained before.