In [None]:
from collections import OrderedDict, defaultdict
import os
import pickle
import re
import time

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import Trainer, TrainingArguments
from datasets import list_metrics, load_metric, Dataset

import emoji

from PIL import Image

from IPython import display

print(torch.__version__)
print(torchvision.__version__)

DATA_DIR = "data"


## Image classifier

As a net I'll use simple ResNet

### Load images

In [None]:
# ImageFolder expects images to be separated by classes
# It's not the case for us, so we'll have to manually update them
with open(f"{DATA_DIR}/labels.pkl", "rb") as f:
    labels = pickle.load(f)
    
no_lbl = set()
for img_path in os.listdir(f"{DATA_DIR}/images"):
    profile_id = img_path.split("_")[0]
    # Move labeled photos into two directories with corresponding labels
    if os.path.isfile(f"{DATA_DIR}/images/{img_path}"):
        try:
            os.renames(f"{DATA_DIR}/images/{img_path}", f"{DATA_DIR}/images/{labels[profile_id]}/{img_path}")
        except KeyError:
            no_lbl.add(profile_id)
            
print(f"{len(no_lbl)} profiles don't have a lable yet")


In [None]:
def load_images(
    data_path: str,
    batch_size: int, 
    shuffle: bool = True,
    num_workers: int = 1,
    transformations: list = None,
    train_size: float = None,
    seed: int = 42
):  
    transforms = torchvision.transforms.Compose(transformations)
    data = torchvision.datasets.ImageFolder(data_path, transforms)
    
    if train_size is not None:
        _train_size = int(len(data) * train_size)
        _val_size = len(data) - _train_size

        print("Train dataset size:", _train_size, "test dataset size:", _val_size)
        
        gen = torch.Generator().manual_seed(seed)
        data_train, data_val = random_split(data, [_train_size, _val_size], generator=gen)

        return (
            DataLoader(data_train, batch_size, shuffle, num_workers=num_workers),
            DataLoader(data_val, batch_size, shuffle, num_workers=num_workers)
        )
    
    return DataLoader(data, batch_size, shuffle, num_workers=num_workers)


In [None]:
# Load images as ImageFolder dataset
# Define the transformations that'll do to images
transforms = [
    torchvision.transforms.Resize((128, 64)), 
    torchvision.transforms.RandomGrayscale(0.3),
    torchvision.transforms.RandomCrop((96, 48), padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()
]
BATCH_SIZE = 8

train_images, val_images = load_images(
    f"{DATA_DIR}/images",
    BATCH_SIZE,
    transformations=transforms,
    train_size=0.85
)


### ResNet-50

In [51]:
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, padding=1):
        super().__init__()

        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=False)
        self.bn_1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_2 = nn.BatchNorm2d(out_channels)

        self.res_con = None
        if stride != 1 and in_channels != out_channels:
            self.res_con = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, X):
        input_ = X
        out = self.bn_1(self.conv_1(X))
        out = self.relu(out)
        out = self.bn_2(self.conv_2(out))

        if self.res_con is not None:
            out += self.res_con(input_)

        out = self.relu(out)
        return out


In [52]:
class ResNet(nn.Module):
    def __init__(self, input_channels, out_channels, layers, num_classes=10):
        self.hidden_size = out_channels

        super().__init__()
        self.conv_1 = nn.Conv2d(
            input_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn_1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.res_1 = self.make_res_block(out_channels, out_channels, layers[0], 1)
        self.res_2 = self.make_res_block(out_channels, 128, layers[1], 2)
        self.res_3 = self.make_res_block(128, 256, layers[2], 2)
        self.res_4 = self.make_res_block(256, 512, layers[3], 2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        self.out = nn.Linear(512, num_classes, bias=False)

    def make_res_block(self, input_channels, output_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(ResBlock(self.hidden_size, output_channels, stride))
            self.hidden_size = output_channels

        return nn.Sequential(*layers)

    def forward(self, X):
        out = self.bn_1(self.conv_1(X))
        out = self.relu(out)

        out = self.res_1(out)
        out = self.res_2(out)
        out = self.res_3(out)
        out = self.res_4(out)

        out = self.avg_pool(out)
        out = self.flat(out)
        out = self.out(out)

        return out


In [53]:
_X = torch.randn((2, 3, 96, 48))
_resnet = ResNet(3, 64, [2, 2, 2, 2], num_classes=2)
print(_resnet(_X).shape)
del _X, _resnet


torch.Size([2, 2])


### Utils

In [54]:
class Plotter:
    def __init__(
        self,
        x_label: str = None,
        y_label: str = None,
        legend: list = None,
        x_lim: list = None,
        y_lim: list = None,
        x_scale: str = "log",
        y_scale: str = "log",
        n_rows: int = 1,
        n_cols: int = 1,
        figsize: tuple = (8, 6)
    ):

        if legend is None:
            legend = []
        self.fig, self.axes = plt.subplots(n_rows, n_cols, figsize=figsize)
        if n_rows * n_cols == 1:
            self.axes = [self.axes,]
        self.config_axes = lambda: self.set_axes(
            self.axes[0], x_label, y_label, x_lim, y_lim, x_scale, y_scale, legend
        )
        self.X, self.Y = None, None

    @staticmethod
    def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
        """Set the axes for matplotlib."""
        axes.set_xlabel(xlabel)
        axes.set_ylabel(ylabel)
        axes.set_xscale(xscale)
        axes.set_yscale(yscale)
        axes.set_xlim(xlim)
        axes.set_ylim(ylim)
        if legend:
            axes.legend(legend)
        axes.grid()

    def add(self, x, y):
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y in zip(self.X, self.Y):
            self.axes[0].plot(x, y)
        self.config_axes()

    def plot(self):
        display.clear_output(wait=True)
        display.display(self.fig)


In [55]:
def init_weights_(layer):
    if isinstance(layer, torch.nn.Conv2d):
        nn.init.kaiming_normal_(layer.weight, mode="fan_out")
    elif isinstance(layer, (nn.BatchNorm2d)):
        nn.init.constant_(layer.weight, 1)
        nn.init.constant_(layer.bias, 0)
    elif isinstance(layer, nn.Linear):
        nn.init.normal_(layer.weight)


In [56]:
def try_gpu():
    if torch.cuda.device_count() > 0:
        return torch.device("cuda:0")

    return torch.device("cpu")


### Training loop

In [57]:
def train_model(
    net, train_iter, test_iter, epochs, optim, loss=None, device=None,
    init_weights=False, debug=False, save_model=None,
    verbose_interval=5, scheduler=None, clip_grad=False
    ):
    # Init stuff
    if init_weights:
        net.apply(init_weights_)

    loss = torch.nn.CrossEntropyLoss() if not loss else loss
    plotter = Plotter(
        x_label="epochs", 
        y_label="loss", 
        x_lim=[1, epochs], 
        legend=["train loss", "test loss"]
    )

    num_batches = len(train_iter)

    device = try_gpu() if not device else device
    print(f"Training on {device}")
    net.to(device)

    # Training loop
    for epoch in range(epochs):
        net.train()
        train_loss_ = []
        
        for i, (X, y) in enumerate(tqdm(train_iter)):
            optim.zero_grad()
            X, y = X.to(device), y.to(device)
            y_pred = net(X)
            l = loss(y_pred, y)
            l.backward()
            train_loss_.append(l.item())
            
            if clip_grad:
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=2., norm_type=2)
            optim.step()

            if debug:
                break

        with torch.no_grad():
            net.eval()
            
            test_loss_ = []
            for X_test, y_test in test_iter:
                X_test, y_test = X_test.to(device), y_test.to(device)
                pred_test = net(X_test)
                test_loss_.append(loss(pred_test, y_test).item())
            
            train_loss = np.mean(train_loss_)
            test_loss = np.mean(test_loss_)
            
            plotter.add(epoch + 1, (train_loss, test_loss))

        if (epoch + 1) % verbose_interval == 0 or epoch == 0 or epoch == (epochs - 1):
                plotter.plot()
                print(
                    f"epoch: {epoch}", f'train loss: {train_loss:.3f}, test loss: {test_loss:.3f}, '
                    f"lr: {optim.param_groups[0]['lr']:.5f}"
                )

        if debug:
            break
        
        if scheduler is not None:
            scheduler.step()

    if save_model is not None:
        torch.save(net.state_dict(), save_model)
        print(f"Model saved to {save_model}")


In [59]:
torch.save(resnet_50.state_dict(), 
           os.path.join("/Users/gleb/Documents/GitHub/tinder_swiper/data/models", "image_model.pt"))


In [58]:
resnet_50 = ResNet(3, 64, [2, 2, 2, 2], num_classes=2)
EPOCHS = 100
lr = 0.1
WD = 5e-4
optim = torch.optim.AdamW(resnet_50.parameters(), lr, weight_decay=WD)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=EPOCHS, eta_min=1e-4)

train_model(resnet_50, train_images, val_images, EPOCHS, optim, init_weights=True, scheduler=scheduler, debug=True)


## Bio classifier

### Load and clean bios

In [None]:
with open(f"{DATA_DIR}/bios.pkl", "rb") as f:
    bios = pickle.load(f)

print(len(bios))
dataset_dict = defaultdict(list)
unlabeled = []

for k, v in bios.items():
    if k in labels:
        dataset_dict["id"].append(k)
        dataset_dict["label"].append(labels[k])
        for k_, v_ in v.items():
            dataset_dict[k_].append(v_)
    else:
        unlabeled.append(k)

print(f"There are {len(unlabeled)} unlabeled profiles")
        
bios_dataset = Dataset.from_dict(dataset_dict)
bios_dataset


In [None]:
def clean_bios(bios, bio_col="bio"):
    def _cleaner(text):
        flags = re.findall(u'[\U0001F1E6-\U0001F1FF]', text)
        
        text = text.replace("\n", " ")
        text = "".join(w for w in text if w not in emoji.EMOJI_DATA and w not in flags)
        text = text.replace("\u200d", "")

        return text

    return {f"{bio_col}_clean": _cleaner(bios[bio_col])}



In [None]:
bios_dataset = bios_dataset.map(clean_bios, batched=False)


### DistilBert

In [None]:
def tokenize(dataset, tokenizer, text_col="bio_clean"):
    return tokenizer(dataset[text_col], padding="max_length", truncation=True, max_length=512)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased")
bios_dataset = bios_dataset.map(tokenize, fn_kwargs={"tokenizer": tokenizer}, batched=True)
bios_dataset


In [None]:
bert = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-multilingual-cased",
    num_labels=2,
    torch_dtype=torch.float32
)
bert


In [None]:
layers_to_train = [
    "distilbert.transformer.layer.5.ffn.lin1.weight",
    "distilbert.transformer.layer.5.ffn.lin1.bias",
    "distilbert.transformer.layer.5.ffn.lin2.weight",
    "distilbert.transformer.layer.5.ffn.lin2.bias",
    "distilbert.transformer.layer.5.output_layer_norm.weight",
    "distilbert.transformer.layer.5.output_layer_norm.bias",
    "pre_classifier.weight",
    "pre_classifier.bias",
    "classifier.weight",
    "classifier.bias",
]

for n, l in bert.named_parameters():
    if n not in layers_to_train:
        l.requires_grad = False


In [None]:
bios_dataset = bios_dataset.train_test_split(test_size=0.15)
text_dataset = (
    bios_dataset.rename_column("bio_clean", "text")
    .remove_columns(["id", "name", "bio", "age"])
)
text_dataset


In [None]:
def compute_metrics(eval_pred, metric=None):
    metric = load_metric("accuracy") if metric is None else metric
    logits, labels = eval_pred
    predictions = np.array([np.argmax(logits, axis=-1)]).flatten()
    return metric.compute(references=labels, prediction_scores=predictions)


In [None]:
training_args = TrainingArguments(
    output_dir="./test_trainer",
    overwrite_output_dir=True,
    evaluation_strategy="steps",
    eval_steps=2_500,
    per_device_train_batch_size=16,
    weight_decay=3e-10,
    num_train_epochs=4,
    lr_scheduler_type="constant_with_warmup",
    save_steps=2_500,
    max_steps=12_500
)

# Create training pipeline
trainer = Trainer(
    model=bert,
    args=training_args,
    train_dataset=bios_dataset["train"],
    eval_dataset=bios_dataset["test"],
    compute_metrics=compute_metrics
)

trainer.train()


ValueError: Column name nanan not in the dataset. Current columns in the dataset: ['id', 'label', 'name', 'bio', 'age', 'bio_clean', 'input_ids', 'attention_mask']