In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading 

In [4]:
import csv
import random
from datetime import datetime
from functools import lru_cache, partial
from pathlib import Path
from typing import List, Tuple

import torch
from datasets import load_dataset
from torch import nn, optim
from torch.nn.utils import rnn
from torch.utils import data
from tqdm import tqdm
from transformers import AutoTokenizer

### Dataset Preprocessing


In [5]:
NUM_PROC = 4


@lru_cache(maxsize=None)
def _load_tokenizer(name: str):
	"""Load the Hugging Face tokenizer `name` once per process and reuse it."""
	return AutoTokenizer.from_pretrained(name)


def tokenize(batch, tokenizer):
	"""Tokenize a batch of texts; intended for `Dataset.map(..., batched=True)`.

	Args:
		batch: list of texts (the column selected through `input_columns`).
		tokenizer: either a tokenizer name/path for `AutoTokenizer.from_pretrained`
			or an already-constructed tokenizer callable.

	Returns:
		dict with a single "tokens" column holding each text's `input_ids`.
	"""
	# `map` calls this once per batch; loading the tokenizer is expensive, so a
	# name is resolved through a per-process cache instead of reloading each call.
	if isinstance(tokenizer, str):
		tokenizer = _load_tokenizer(tokenizer)

	# truncation=True with no max_length: HF truncates to the tokenizer's
	# model max length.
	encoded = tokenizer(
		batch,
		return_attention_mask=False,
		return_token_type_ids=False,
		truncation=True,
	)
	return {"tokens": encoded["input_ids"]}


def prepare_dataset(
		dataset: str,
		tokenizer: str,
		tokenized_col: str,
		selected_col: List[str] = None,
		num_proc: int = NUM_PROC,
):
	"""Load a Hugging Face dataset and tokenize one of its text columns.

	Args:
		dataset: dataset name/path passed to `load_dataset`.
		tokenizer: tokenizer name forwarded to `tokenize` for every batch.
		tokenized_col: name of the column whose text gets tokenized.
		selected_col: extra columns to keep next to "tokens" (default: none).
		num_proc: number of worker processes used by `map`.

	Returns:
		The mapped dataset reduced to the "tokens" column plus `selected_col`.
	"""
	extra_cols = [] if selected_col is None else selected_col

	tokenized = load_dataset(dataset).map(
		tokenize,
		fn_kwargs={"tokenizer": tokenizer},
		input_columns=[tokenized_col],
		num_proc=num_proc,
		batched=True,
	)

	return tokenized.select_columns(["tokens"] + extra_cols)

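### Gated Architecture

Each encoder layer is a post-norm Transformer block followed by a soft gate. The gate scores every token and drops those scoring at or below a threshold, so the sequence can shrink from layer to layer.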

In [6]:
class SoftGate(nn.Module):
	"""Score every token and drop those whose gate value does not exceed a threshold.

	Kept tokens are scaled by their gate value and packed to the left of a
	shorter, zero-filled sequence; a new padding mask is returned with them.
	"""

	def __init__(
		self,
		embed_dim: int,
		threshold: float = 1e-1,
		steepness: float = 10.0,
	) -> None:
		"""
		Args:
			embed_dim: feature size of the incoming tokens.
			threshold: tokens with gate value <= threshold are dropped.
			steepness: slope of the tanh squashing; larger values push the
				gate closer to a hard 0/1 decision.
		"""
		super(SoftGate, self).__init__()
		bottleneck_dim = max(2, embed_dim // 2)
		self.threshold = threshold
		self.steepness = steepness
		self.fc = nn.Sequential(
			nn.Linear(embed_dim, bottleneck_dim),
			nn.ReLU(),
			nn.Linear(bottleneck_dim, 1),
		)

	def forward(
		self,
		x: torch.Tensor,
		pad: torch.Tensor,
	) -> Tuple[
		torch.Tensor,
		torch.Tensor,
	]:
		"""Gate and compact a batch of token sequences.

		Args:
			x: token features, indexed as (batch, seq, embed) below (batch-first).
			pad: bool mask of shape (batch, seq), True at padding positions.
				Padding may sit anywhere in the row.

		Returns:
			(v, v_pad): kept tokens packed left into (batch, max_kept, embed), and
			a bool mask that is True where a row has no kept token.
		"""
		# Gate value in [0, 1]: (1 + tanh(steepness * logit)) / 2.
		y = self.fc(x)
		y = (1 + torch.tanh(self.steepness * y)) / 2
		# Padding positions never open the gate (vectorized; no per-row host sync).
		y = y.masked_fill(pad.unsqueeze(dim=-1), 0)

		# If a row's largest gate value is below the threshold, shift the whole
		# row up so that maximum lands 1e-5 above it and at least one token survives.
		adjust = 1e-5 + self.threshold
		adjust = adjust - y.max(dim=1).values
		adjust = adjust.clamp(min=0).unsqueeze(dim=-1)

		y = y + adjust
		u = x * y
		y = y.squeeze(dim=-1)

		u_mask = (y > self.threshold) & (~pad)
		length = u_mask.sum(dim=1)
		stride = length.max().item()

		size = list(u.size())
		size[1] = stride

		v = torch.zeros(size, dtype=u.dtype, device=u.device)
		v_mask = torch.arange(stride, device=u.device)
		v_mask = v_mask < length.unsqueeze(dim=-1)

		# Boolean indexing walks both masks in row-major order, so each row's kept
		# tokens land, in order, in that row's leading slots.
		v[v_mask] = u[u_mask]
		v_pad = ~v_mask

		return v, v_pad



class GatedEncoderLayer(nn.Module):
	"""Post-norm Transformer encoder layer whose output is filtered by a SoftGate.

	Self-attention and a feed-forward sublayer (each with a residual connection
	and LayerNorm) are followed by a gate that may drop tokens, so the returned
	sequence can be shorter than the input.
	"""

	def __init__(
		self,
		embed_dim: int,
		heads_num: int,
		fc_dim: int,
	):
		super().__init__()
		# Submodules are created in a fixed order; this keeps parameter init
		# (RNG consumption) and state_dict keys stable.
		self.attention = nn.MultiheadAttention(
			embed_dim=embed_dim,
			num_heads=heads_num,
			batch_first=True,
		)
		self.gate = SoftGate(embed_dim)
		self.norm1 = nn.LayerNorm(embed_dim)
		self.norm2 = nn.LayerNorm(embed_dim)
		self.ff = nn.Sequential(
			nn.Linear(embed_dim, fc_dim),
			nn.SiLU(),
			nn.Dropout(p=0.5),
			nn.Linear(fc_dim, embed_dim),
			nn.Dropout(p=0.5),
		)

	def forward(
		self,
		x: torch.Tensor,
		pad: torch.Tensor,
	) -> Tuple[
		torch.Tensor,
		torch.Tensor,
	]:
		"""Encode `x` (batch-first) with padding mask `pad`; returns the gate's (tokens, pad)."""
		attended, _ = self.attention(
			x,
			x,
			x,
			key_padding_mask=pad,
			need_weights=False,
			attn_mask=None,
			average_attn_weights=False,
			is_causal=False,
		)

		hidden = self.norm1(x + attended)
		hidden = self.norm2(hidden + self.ff(hidden))
		return self.gate(hidden, pad)


class GatedEncoder(nn.Module):
	"""A stack of `gates_num` GatedEncoderLayer modules applied in sequence.

	Each layer returns a possibly shorter sequence together with its own
	padding mask, which is fed to the next layer.
	"""

	def __init__(
		self,
		gates_num: int,
		embed_dim: int,
		heads_num: int,
		fc_dim: int,
	):
		super().__init__()
		layers = []
		for _ in range(gates_num):
			layers.append(GatedEncoderLayer(
				embed_dim=embed_dim,
				heads_num=heads_num,
				fc_dim=fc_dim,
			))
		self.layers = nn.ModuleList(layers)

	def forward(
		self,
		src: torch.Tensor,
		src_pad: torch.Tensor,
	) -> Tuple[torch.Tensor, torch.Tensor]:
		"""Run all layers; returns the final (tokens, padding mask)."""
		hidden, hidden_pad = src, src_pad
		for layer in self.layers:
			hidden, hidden_pad = layer(hidden, hidden_pad)
		return hidden, hidden_pad

### Model

In [7]:
class Classifier(nn.Module):
	"""Token embedding -> gated encoder -> linear head on the first remaining token.

	`forward` also returns an activation penalty and the fraction of tokens
	that survived the gates.
	"""

	def __init__(
		self,
		vocab_size: int,
		embed_dim: int,
		pad_idx: int,

		encoder_gates_num: int,
		encoder_heads_num: int,
		encoder_fc_dim: int,

		class_num: int,
	):
		super().__init__()
		# Creation order is kept fixed so parameter init and state_dict keys match.
		self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
		self.encoder = GatedEncoder(
			gates_num=encoder_gates_num,
			embed_dim=embed_dim,
			heads_num=encoder_heads_num,
			fc_dim=encoder_fc_dim,
		)
		self.classifier = nn.Sequential(
			nn.Dropout(p=0.5),
			nn.Linear(embed_dim, class_num),
		)

	def forward(
		self,
		src: torch.Tensor,
		src_pad: torch.Tensor,
	) -> Tuple[
		torch.Tensor,
		torch.Tensor,
		float,
	]:
		"""Classify token ids `src` with padding mask `src_pad` (True = padding).

		Returns:
			logits: (batch, class_num) scores from the first position of the
				encoder output.
			penalty: mean squared value of the encoder output (padding slots
				included); epoch_pass adds it to the loss.
			ratio: batch mean of kept-token count / input-token count.
		"""
		encoded, encoded_pad = self.encoder(self.embedding(src), src_pad)

		penalty = encoded.pow(2).mean()
		logits = self.classifier(encoded[:, 0, :])

		kept = (~encoded_pad).sum(dim=1)
		total = (~src_pad).sum(dim=1)
		ratio = (kept / total).mean(dim=0).item()

		return logits, penalty, ratio

### Utility functions

In [8]:
def ckpt_save(
	path: Path,
	model: nn.Module,
	optimizer: optim.Optimizer,
):
	"""Write model and optimizer state dicts to `path` as {"model": ..., "optim": ...}."""
	torch.save(
		{"model": model.state_dict(), "optim": optimizer.state_dict()},
		path,
	)


def log_save(path: Path, log: List):
	"""Write per-step log dicts to `path` as a tab-separated file with a header row.

	The header is taken from the keys of the first entry; csv.DictWriter raises
	ValueError if a later entry has keys not in that header. An empty `log`
	produces an empty file instead of failing.
	"""
	# The csv module requires newline="" so it controls line endings itself
	# (otherwise "\r\n" rows become "\r\r\n" on Windows).
	with path.open("w", newline="") as file:
		if not log:
			return
		writer = csv.DictWriter(file, fieldnames=list(log[0].keys()), delimiter="\t")
		writer.writeheader()
		writer.writerows(log)

In [9]:
def collate(batch, pad):
	"""Collate tokenized examples into right-padded tensors.

	Args:
		batch: list of examples, each with "tokens" (list of ids) and "label".
		pad: id used to fill positions past each sequence's end.

	Returns:
		x: (batch, max_len) token ids.
		x_pad: (batch, max_len) bool mask, True at padding positions.
		y: (batch,) float targets, 1.0 where label >= 3 else 0.0.
	"""
	# NOTE(review): if labels are yelp_review_full's 0-4 star classes, this marks
	# 4-5 stars as positive and 1-3 stars as negative — confirm intended.
	targets = [example["label"] >= 3 for example in batch]
	sequences = [torch.tensor(example["tokens"]) for example in batch]
	lengths = torch.tensor([len(example["tokens"]) for example in batch])

	y = torch.tensor(targets, dtype=torch.float)
	x = rnn.pad_sequence(sequences, batch_first=True, padding_value=pad)

	positions = torch.arange(x.size(1), device=x.device)
	x_pad = positions.unsqueeze(0) >= lengths.unsqueeze(1)

	return x, x_pad, y

### Training / evaluation epoch

In [10]:
def epoch_pass(
	epoch: int,

	device: torch.device,
	model: nn.Module,
	criterion: nn.Module,
	loader: data.DataLoader,

	optimizer: optim.Optimizer | None = None,
	ckpt_dir: Path = None,
	ckpt_freq: int = 10,
):
	"""Run one pass over `loader`: train when `optimizer` is given, else evaluate.

	Args:
		epoch: epoch index, used in the progress-bar label and checkpoint paths.
		device: device each batch is moved to.
		model: called as `model(x, x_pad)`, returning `(logits, aux_loss, ratio)`.
		criterion: loss on the squeezed logits and targets; `aux_loss` is added to it.
		loader: sized iterable of `(x, x_pad, y)` batches.
		optimizer: when None the pass runs in eval mode with gradients disabled.
		ckpt_dir: if set, checkpoints (train only) and the TSV log are written
			under `ckpt_dir/<train|test>/<epoch>/`.
		ckpt_freq: approximate number of checkpoint/log dumps per epoch.

	Returns:
		The list of per-batch log dicts.
	"""
	test = optimizer is None
	desc = "test" if test else "train"
	# model.train(False) is the same as model.eval().
	model.train(not test)

	loader_len = len(loader)
	loss_total = 0.0
	correct_total = 0
	pred_total = 0
	log = []

	# Dump every `ckpt_every` batches (fractional; rounded at the check below).
	# max(..., 1) avoids a ZeroDivisionError for ckpt_freq <= 0.
	ckpt_every = loader_len / max(ckpt_freq, 1)
	ckpt_n = 0

	if ckpt_dir:
		ckpt_dir = ckpt_dir / desc / str(epoch)
		ckpt_dir.mkdir(parents=True)

	# Grad mode is scoped to this pass (restored on exit) instead of flipping the
	# global flag, and the progress bar is closed even if a batch raises.
	with torch.set_grad_enabled(not test), tqdm(desc=f"{desc} {epoch}", total=loader_len) as bar:
		for i, (x, x_pad, y) in enumerate(loader):
			x = x.to(device)
			y = y.to(device)
			x_pad = x_pad.to(device)

			if not test:
				optimizer.zero_grad()

			y_pred, z, ratio = model(x, x_pad)
			y_pred = y_pred.squeeze(-1)
			loss = criterion(y_pred, y) + z

			if not test:
				loss.backward()
				optimizer.step()
			# NOTE(review): releasing the CUDA cache every step forces re-allocation
			# and likely slows training; presumably here to limit fragmentation from
			# variable-length batches — confirm it is still needed.
			torch.cuda.empty_cache()

			loss = loss.item()
			loss_total += loss
			loss_avg = loss_total / (i + 1)

			batch_size = y_pred.size(0)
			# Binary accuracy: logit -> probability, threshold at 0.5.
			correct = y_pred.sigmoid().ge(0.5).eq(y).sum().item()

			pred_total += batch_size
			correct_total += correct

			acc = correct / batch_size
			acc_avg = correct_total / pred_total

			log_ent = {
				"ratio": ratio,
				"acc": acc * 100,
				"acc_": acc_avg * 100,
				"loss": loss,
				"loss_": loss_avg,
			}
			log.append(log_ent)

			bar.set_postfix(**log_ent)
			bar.update(1)

			if ckpt_dir and (i + 1) >= round(ckpt_every * (ckpt_n + 1)):
				if not test:
					ckpt_path = ckpt_dir / f"{ckpt_n}.pth"
					ckpt_save(ckpt_path, model, optimizer)

				log_path = ckpt_dir / f"{epoch}.tsv"
				log_save(log_path, log)
				ckpt_n += 1

	return log

In [11]:
def main():
	"""Train the gated classifier on Yelp reviews, then evaluate on the test split.

	Checkpoints and TSV logs go under checkpoint/<MMDD_HHMM>/.
	"""
	ckpt_dir = Path("checkpoint", datetime.now().strftime("%m%d_%H%M"))
	ckpt_dir.mkdir(parents=True, exist_ok=False)

	# Seed before model init and data shuffling.
	torch.random.manual_seed(42)
	random.seed(42)

	dataset_name = "Yelp/yelp_review_full"
	tokenizer_name = "google-bert/bert-base-uncased"

	tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
	pad = tokenizer.pad_token_id

	dataset = prepare_dataset(
		dataset=dataset_name,
		tokenizer=tokenizer_name,
		tokenized_col="text",
		selected_col=["label"],
	)

	collate_fn = partial(collate, pad=pad)
	train_loader = data.DataLoader(
		dataset["train"],
		batch_size=32,
		shuffle=True,
		collate_fn=collate_fn,
	)
	test_loader = data.DataLoader(
		dataset["test"],
		batch_size=256,
		collate_fn=collate_fn,
	)

	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	model = Classifier(
		vocab_size=tokenizer.vocab_size,
		pad_idx=pad,
		embed_dim=64,

		encoder_gates_num=3,
		encoder_heads_num=1,
		encoder_fc_dim=128,

		class_num=1,
	)

	param_count = sum(param.numel() for param in model.parameters())
	print("Prams", param_count, "Vocab", tokenizer.vocab_size)
	model.to(device)

	criterion = nn.BCEWithLogitsLoss()
	optimizer = optim.Adam(model.parameters(), lr=0.001)

	epochs = 1
	ckpt_freq = 10

	for epoch in range(epochs):
		epoch_pass(
			epoch=epoch,
			device=device,
			model=model,
			criterion=criterion,
			loader=train_loader,
			optimizer=optimizer,
			ckpt_dir=ckpt_dir,
			ckpt_freq=ckpt_freq,
		)
		epoch_pass(
			epoch=epoch,
			device=device,
			model=model,
			criterion=criterion,
			loader=test_loader,
		)

In [12]:
# Run the full pipeline: tokenize the dataset, train, then evaluate on the test split.
main()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/6.72k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/299M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/23.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/650000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/650000 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/50000 [00:00<?, ? examples/s]

Prams 2060228 Vocab 30522


train 0: 100%|██████████| 20313/20313 [17:06<00:00, 19.78it/s, acc=81.2, acc_=85.4, loss=0.261, loss_=0.34, ratio=0.85]
test 0: 100%|██████████| 196/196 [00:24<00:00,  8.06it/s, acc=90, acc_=87.6, loss=0.245, loss_=0.287, ratio=0.881]
