In [2]:
import os, sys
import torch
import torch.nn as nn

from torchvision.datasets import CIFAR10
from torchvision import models, transforms
from greenformer import auto_fact
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm

In [3]:
def count_param(module, trainable=False):
    """Return the total number of scalar parameters held by ``module``.

    Args:
        module: object exposing ``parameters()`` (e.g. an ``nn.Module``).
        trainable: when truthy, only parameters whose ``requires_grad`` flag
            is set are counted; otherwise every parameter is counted.

    Returns:
        int: sum of ``numel()`` over the selected parameters.
    """
    total = 0
    for param in module.parameters():
        # Frozen parameters are skipped only in "trainable" mode.
        if trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total

# Init Model

In [4]:
class CNNModel(nn.Module):
    """Image classifier: a backbone module followed by a linear head.

    Args:
        model: backbone module; its output is fed to the linear head.
        latent_features: input size of the linear head.
        out_features: output size of the linear head (number of classes).
    """

    def __init__(self, model, latent_features, out_features):
        super().__init__()
        self.latent_features = latent_features
        self.out_features = out_features
        head = nn.Linear(latent_features, out_features)
        self.model = nn.Sequential(model, head)

    def forward(self, inputs, labels=None, *args, **kwargs):
        """Run the classifier.

        Args:
            inputs: input tensor; when its second dimension has size 1 it is
                tiled three times along that dimension before the backbone.
            labels: optional targets; when given, a cross-entropy loss is
                computed against the logits.

        Returns:
            ``(logits,)`` when ``labels`` is None, otherwise ``(loss, logits)``.
        """
        # Single-channel input is repeated along dim 1 to give 3 channels.
        if inputs.shape[1] == 1:
            inputs = inputs.repeat(1, 3, 1, 1)

        logits = self.model(inputs)
        if labels is None:
            return (logits,)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, self.out_features), labels.view(-1))
        return (loss, logits)

In [5]:
# Backbone: torchvision ResNeXt-50 (32x4d) with pretrained weights.
cnn_module = models.resnext50_32x4d(pretrained=True)
# Swap the backbone's own `fc` head for dropout; the 10-way classification
# head is added by CNNModel. The tensor reaching `fc` is presumably the
# flattened (N, 2048) feature matrix (matches latent_features below), so plain
# element-wise nn.Dropout is used: nn.Dropout2d is documented for 3-D/4-D
# inputs and 2-D input to it is deprecated.
cnn_module.fc = nn.Dropout(0.1)
model = CNNModel(cnn_module, latent_features=2048, out_features=10)

In [6]:
# Total parameter count of the unfactorized model (trainable=False, so every parameter is counted).
count_param(model)

23000394

# Apply Factorization-by-design

In [7]:
%%time
# Only the backbone's `layer3` and `layer4` submodules are handed to auto_fact
# for factorization (model.model[0] is the backbone inside the nn.Sequential).
factorized_submodules = [model.model[0].layer3, model.model[0].layer4]
# NOTE(review): presumably rank=0.5 is a ratio relative to each layer's maximum
# rank, and deepcopy=True leaves `model` untouched and returns a new module —
# confirm against the greenformer documentation.
# NOTE(review): `num_iter` may be irrelevant for solver='random' — confirm.
fact_model = auto_fact(model, rank=0.5, deepcopy=True, solver='random', num_iter=20, submodules=factorized_submodules)
# Parameter count after factorization, for comparison with the original model.
print(count_param(fact_model))

12889418
CPU times: user 1.87 s, sys: 24 ms, total: 1.89 s
Wall time: 285 ms




# Speed test on CPU

### Test Inference CPU

In [8]:
%%timeit
with torch.no_grad():
    y = model(torch.zeros(8,3,224,224, dtype=torch.float))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


364 ms ± 143 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%%timeit
with torch.no_grad():
    y = fact_model(torch.zeros(8,3,224,224, dtype=torch.float))

293 ms ± 44.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward CPU

In [10]:
%%timeit
# Times one forward + backward pass on CPU for a batch of 8 zero images.
# No labels are passed, so the model returns (logits,) and the backward pass
# starts from the sum of the logits rather than from a loss.
# NOTE(review): gradients are never zeroed between timed runs, so .grad keeps
# accumulating; the later training loop calls optimizer.zero_grad() first.
# NOTE(review): if the model is in training mode here, BatchNorm running
# statistics are presumably updated from the all-zero input — confirm whether
# that matters for the fine-tuning that follows.
y = model(torch.zeros(8,3,224,224, dtype=torch.float))
y[0].sum().backward()

1.76 s ± 81.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
%%timeit
# Times one forward + backward pass of the factorized model on CPU for a batch
# of 8 zero images. The backward pass starts from the sum of the first output
# (presumably the logits, as for the unfactorized model — no labels are passed).
# NOTE(review): gradients are never zeroed between timed runs, so .grad keeps
# accumulating; the later training loop calls optimizer.zero_grad() first.
y = fact_model(torch.zeros(8,3,224,224, dtype=torch.float))
y[0].sum().backward()

2 s ± 95.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Speed test on GPU

### Move models to GPU

In [12]:
# Both the original and the factorized network are moved to the GPU for the GPU benchmarks.
model, fact_model = model.cuda(), fact_model.cuda()

### Test Inference GPU

In [13]:
# Batch of 64 zero images, allocated directly on the current CUDA device
# (avoids building the tensor on the host and then copying it over).
x = torch.zeros(64, 3, 224, 224, dtype=torch.float, device='cuda')

In [14]:
%%timeit
with torch.no_grad():
    y = model(x)

The slowest run took 13.47 times longer than the fastest. This could mean that an intermediate result is being cached.
148 ms ± 55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
%%timeit
with torch.no_grad():
    y = fact_model(x)

160 ms ± 224 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Test Forward-Backward GPU

In [16]:
# Batch of 64 zero images for the forward-backward benchmark, allocated
# directly on the current CUDA device (no host-side allocation + copy).
x = torch.zeros(64, 3, 224, 224, dtype=torch.float, device='cuda')

In [17]:
%%timeit
y = model(x)
y[0].sum().backward()

656 ms ± 3.37 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
%%timeit
y = fact_model(x)
y[0].sum().backward()

631 ms ± 4.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Prepare Dataset and DataLoader

In [19]:
# CIFAR10 Dataset
class CIFAR10Dataset(Dataset):
    """CIFAR-10 wrapper exposing 'train', 'validation' and 'test' splits.

    'train' and 'validation' are both carved out of the official training set:
    'validation' is its last 1000 samples, 'train' is everything before them.
    'test' is the official test set. Only 'train' gets random-crop and
    horizontal-flip augmentation; every split is converted to a tensor and
    normalized with the same per-channel mean/std.
    """

    # Static constant variable
    NUM_LABELS = 10

    def __init__(self, data_split, *args, **kwargs):
        """Build the underlying torchvision dataset for ``data_split``.

        Args:
            data_split: one of 'train', 'validation', 'test'.

        Raises:
            ValueError: if ``data_split`` is not one of the three names.
        """
        self.data_split = data_split
        if data_split not in ('train', 'validation', 'test'):
            raise ValueError(f'Invalid dataset split: `{data_split}`')

        # Augmentation steps come first and only apply to the training split.
        steps = []
        if data_split == 'train':
            steps.append(transforms.RandomCrop(32, padding=4))
            steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.2435, 0.2616)))

        # 'train' and 'validation' both read the official training set.
        self.dataset = CIFAR10(
            './cifar10',
            download=True,
            train=(data_split != 'test'),
            transform=transforms.Compose(steps),
        )

        # Hold out the last 1000 training samples as the validation split.
        if data_split == 'train':
            self.dataset.data = self.dataset.data[:-1000]
            self.dataset.targets = self.dataset.targets[:-1000]
        elif data_split == 'validation':
            self.dataset.data = self.dataset.data[-1000:]
            self.dataset.targets = self.dataset.targets[-1000:]

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        sample, target = self.dataset[index]
        return sample, target

    def __len__(self):
        """Number of samples in this split."""
        return len(self.dataset)


In [20]:
# One dataset object per split.
train_dataset = CIFAR10Dataset('train')
valid_dataset = CIFAR10Dataset('validation')
test_dataset = CIFAR10Dataset('test')

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True, num_workers=8)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=256, shuffle=False, num_workers=8)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False, num_workers=8)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar10/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./cifar10/cifar-10-python.tar.gz to ./cifar10
Files already downloaded and verified
Files already downloaded and verified


# Run Training & Evaluation

In [21]:
# Forward function for image classification
def forward_image_classification(model, batch_data, device='cpu', **kwargs):
    """Run one classification batch through ``model``.

    Args:
        model: callable invoked as ``model(inputs, labels=labels)`` whose first
            two outputs are ``(loss, logits)``.
        batch_data: ``(input_batch, label_batch)`` pair of tensors.
        device: target device. When it names a CUDA device ("cuda", "cuda:1",
            ``torch.device('cuda')``, ...) both tensors are moved there; for
            any other value the tensors are left where they are.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        tuple: ``(loss, list_hyp, list_label)`` where ``list_hyp`` holds the
        top-1 predicted class index per sample and ``list_label`` the gold
        labels, both as plain Python lists.
    """
    # Unpack batch data
    input_batch, label_batch = batch_data

    # Prepare input & label (any CUDA device spec, not just the literal "cuda")
    if str(device).startswith("cuda"):
        input_batch = input_batch.to(device)
        label_batch = label_batch.to(device)

    # Forward model
    outputs = model(input_batch, labels=label_batch)
    loss, logits = outputs[:2]

    # Top-1 prediction per sample. tolist() converts the whole batch in one
    # transfer instead of one device synchronization per element via .item().
    hyp = torch.topk(logits, 1)[1]
    list_hyp = hyp.reshape(-1).tolist()
    list_label = label_batch.reshape(-1).tolist()

    return loss, list_hyp, list_label

# Metric function for calculating Accuracy and F1
def acc_f1_metrics_fn(list_hyp, list_label):
    """Compute accuracy and macro-averaged F1.

    Args:
        list_hyp: predicted labels.
        list_label: gold labels, aligned with ``list_hyp``.

    Returns:
        dict: ``{"ACC": accuracy, "F1": macro F1}``.
    """
    return {
        "ACC": accuracy_score(list_label, list_hyp),
        "F1": f1_score(list_label, list_hyp, average='macro'),
    }

In [22]:
###
# modelling functions
###
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Returns None when ``optimizer.param_groups`` is empty.
    """
    _missing = object()
    first_group = next(iter(optimizer.param_groups), _missing)
    if first_group is _missing:
        return None
    return first_group['lr']

def metrics_to_string(metric_dict):
    """Format a metrics dict as space-separated ``KEY:value`` pairs.

    Values are rendered with two decimal places, in the dict's iteration order.
    """
    return ' '.join(
        '{}:{:.2f}'.format(name, score) for name, score in metric_dict.items()
    )

###
# Training & Evaluation Function
###

# Evaluate function for validation and test
def evaluate(model, data_loader, forward_fn, metrics_fn, is_test=False, device='cpu'):
    """Evaluate ``model`` over ``data_loader`` without tracking gradients.

    Args:
        model: module to evaluate; it is switched to eval mode and left there.
        data_loader: iterable of batches with a defined ``len()``.
        forward_fn: called as ``forward_fn(model, batch, device=device)`` and
            expected to return ``(loss, batch_hyp, batch_label)``.
        metrics_fn: called as ``metrics_fn(list_hyp, list_label)``; returns a
            dict of metric name -> value.
        is_test: selects the progress-bar label ("TEST" vs "VALID") and the
            return shape.
        device: forwarded to ``forward_fn``.

    Returns:
        ``(total_loss, metrics)`` when ``is_test`` is False, otherwise
        ``(total_loss, metrics, list_hyp, list_label)``. ``total_loss`` is the
        sum of per-batch losses (not averaged). ``metrics`` is an empty dict
        if the loader yields no batches.
    """
    model.eval()

    total_loss = 0
    list_hyp, list_label = [], []
    metrics = {}  # stays empty if the loader yields nothing
    prefix = "TEST" if is_test else "VALID"

    # Scoped no_grad: the caller's autograd mode is restored on exit instead
    # of leaving gradient tracking globally disabled.
    with torch.no_grad():
        pbar = tqdm(iter(data_loader), leave=True, total=len(data_loader))
        for i, batch_data in enumerate(pbar):
            loss, batch_hyp, batch_label = forward_fn(model, batch_data, device=device)

            # Calculate total loss
            total_loss = total_loss + loss.item()

            # Running metrics over everything seen so far (for the progress bar)
            list_hyp += batch_hyp
            list_label += batch_label
            metrics = metrics_fn(list_hyp, list_label)

            pbar.set_description("{} LOSS:{:.4f} {}".format(
                prefix, total_loss/(i+1), metrics_to_string(metrics)))

    if is_test:
        return total_loss, metrics, list_hyp, list_label
    else:
        return total_loss, metrics

# Training function and trainer
def train(model, train_loader, valid_loader, optimizer, forward_fn, metrics_fn, valid_criterion, n_epochs,
              evaluate_every=1, early_stop=3, step_size=1, gamma=0.5, device="cpu"):
    """Train ``model`` with gradient clipping, StepLR decay and early stopping.

    Mixed precision (autocast + GradScaler) is used when ``device`` names a
    CUDA device; otherwise both are disabled and training runs in full
    precision.

    Args:
        model: module to train.
        train_loader: training batches (iterable with ``len()``).
        valid_loader: validation batches, passed to ``evaluate``.
        optimizer: optimizer over ``model``'s parameters.
        forward_fn: called as ``forward_fn(model, batch, device=device)``;
            returns ``(loss, batch_hyp, batch_label)``.
        metrics_fn: called as ``metrics_fn(list_hyp, list_label)``; returns a
            dict of metric name -> value.
        valid_criterion: key of the validation metric used for model selection
            (higher is better).
        n_epochs: maximum number of epochs.
        evaluate_every: validate every this many epochs.
        early_stop: stop once this many validations fail to improve the best
            metric.
        step_size, gamma: StepLR parameters.
        device: forwarded to ``forward_fn`` / ``evaluate``.

    Returns:
        The model in its state after the last executed epoch. The best
        checkpoint (by ``valid_criterion``) is written to "./best_model.th".
    """
    use_amp = str(device).startswith("cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    # -inf so the first validation always counts as an improvement, whatever
    # the metric's range.
    best_val_metric = float("-inf")
    count_stop = 0

    for epoch in range(n_epochs):
        model.train()
        # evaluate() may have left gradient tracking disabled.
        torch.set_grad_enabled(True)

        total_train_loss = 0
        list_hyp, list_label = [], []

        train_pbar = tqdm(iter(train_loader), leave=True, total=len(train_loader))
        for i, batch_data in enumerate(train_pbar):
            optimizer.zero_grad()

            # Autocast wraps only the forward pass and loss computation; the
            # backward pass and optimizer step run outside the context.
            with torch.cuda.amp.autocast(enabled=use_amp):
                loss, batch_hyp, batch_label = forward_fn(model, batch_data, device=device)

            # Scales the loss, and calls backward() to create scaled gradients
            scaler.scale(loss).backward()

            # Unscales the gradients of optimizer's assigned params in-place
            scaler.unscale_(optimizer)

            # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.9)

            # Skips the step if gradients contain inf/NaN, otherwise calls optimizer.step()
            scaler.step(optimizer)

            # Updates the scale for next iteration
            scaler.update()

            tr_loss = loss.item()
            total_train_loss = total_train_loss + tr_loss

            # Collect predictions/labels for the epoch-level metrics
            list_hyp += batch_hyp
            list_label += batch_label

            train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
                total_train_loss/(i+1), get_lr(optimizer)))

        metrics = metrics_fn(list_hyp, list_label)
        print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
            total_train_loss/(i+1), metrics_to_string(metrics), get_lr(optimizer)))

        # Decay Learning Rate
        scheduler.step()

        # evaluate
        if ((epoch+1) % evaluate_every) == 0:
            val_loss, val_metrics = evaluate(model, valid_loader, forward_fn, metrics_fn, is_test=False, device=device)

            # Early stopping: checkpoint on strict improvement, otherwise
            # count consecutive non-improving validations.
            val_metric = val_metrics[valid_criterion]
            if best_val_metric < val_metric:
                best_val_metric = val_metric
                torch.save(model.state_dict(), "./best_model.th")
                count_stop = 0
            else:
                count_stop += 1
                print("count stop:", count_stop)
                if count_stop == early_stop:
                    break

    # Return
    return model

In [23]:
# Train on Original model
# Fine-tune for up to 5 epochs with AdamW (lr=0.001); the learning rate is
# multiplied by gamma=0.9 after every epoch and validation accuracy ('ACC')
# drives checkpointing / early stopping inside train().
model = model.cuda()
optimizer = AdamW(model.parameters(), lr=0.001)
model = train(model, train_loader=train_loader, valid_loader=valid_loader, optimizer=optimizer, 
    forward_fn=forward_image_classification, metrics_fn=acc_f1_metrics_fn, valid_criterion='ACC', 
    n_epochs=5, evaluate_every=1, early_stop=3, step_size=1, gamma=0.9, device='cuda'
)

# Load best model
# train() returns the model as of the last epoch; the best-by-validation
# weights are the ones train() saved to ./best_model.th.
# NOTE(review): the factorized-model cell reads and writes the same
# ./best_model.th path, so the two cells must not be interleaved.
model.load_state_dict(torch.load("./best_model.th"))

# Evaluation phase
print('=== Evaluation Phase ===')
test_loss, test_metrics, test_hyp, test_label = evaluate(model, data_loader=test_loader, 
        forward_fn=forward_image_classification, metrics_fn=acc_f1_metrics_fn, is_test=True, device='cuda')
print(test_metrics)

# Drop the references to the optimizer and the trained model.
# NOTE(review): `model` is undefined after this line, so earlier cells that use
# it cannot be re-run without re-creating it.
del optimizer, model

  torch.nn.utils.clip_grad_norm_(model.parameters(), 0.9)
(Epoch 1) TRAIN LOSS:1.0135 LR:0.00100000: 100%|████████| 192/192 [00:26<00:00,  7.20it/s]

(Epoch 1) TRAIN LOSS:1.0135 ACC:0.66 F1:0.66 LR:0.00100000



VALID LOSS:0.6483 ACC:0.80 F1:0.79: 100%|███████████████████| 4/4 [00:00<00:00, 14.45it/s]
(Epoch 2) TRAIN LOSS:0.6099 LR:0.00090000: 100%|████████| 192/192 [00:26<00:00,  7.19it/s]

(Epoch 2) TRAIN LOSS:0.6099 ACC:0.80 F1:0.80 LR:0.00090000



VALID LOSS:0.6325 ACC:0.80 F1:0.80: 100%|███████████████████| 4/4 [00:00<00:00, 12.40it/s]
(Epoch 3) TRAIN LOSS:0.5011 LR:0.00081000: 100%|████████| 192/192 [00:26<00:00,  7.15it/s]

(Epoch 3) TRAIN LOSS:0.5011 ACC:0.83 F1:0.83 LR:0.00081000



VALID LOSS:0.5500 ACC:0.82 F1:0.82: 100%|███████████████████| 4/4 [00:00<00:00, 13.18it/s]
(Epoch 4) TRAIN LOSS:0.4354 LR:0.00072900: 100%|████████| 192/192 [00:26<00:00,  7.20it/s]

(Epoch 4) TRAIN LOSS:0.4354 ACC:0.85 F1:0.85 LR:0.00072900



VALID LOSS:0.4266 ACC:0.85 F1:0.84: 100%|███████████████████| 4/4 [00:00<00:00, 14.26it/s]
(Epoch 5) TRAIN LOSS:0.3716 LR:0.00065610: 100%|████████| 192/192 [00:26<00:00,  7.16it/s]

(Epoch 5) TRAIN LOSS:0.3716 ACC:0.87 F1:0.87 LR:0.00065610



VALID LOSS:0.4399 ACC:0.86 F1:0.86: 100%|███████████████████| 4/4 [00:00<00:00, 14.19it/s]


=== Evaluation Phase ===


TEST LOSS:0.4443 ACC:0.85 F1:0.85: 100%|██████████████████| 40/40 [00:02<00:00, 17.58it/s]

{'ACC': 0.8486, 'F1': 0.8493313984463514}





In [24]:
# Train on factorized model
# Same recipe as for the original model: up to 5 epochs, AdamW (lr=0.001),
# learning rate multiplied by gamma=0.9 after every epoch, validation accuracy
# ('ACC') used for checkpointing / early stopping inside train().
fact_model = fact_model.cuda()
optimizer = AdamW(fact_model.parameters(), lr=0.001)
fact_model = train(fact_model, train_loader=train_loader, valid_loader=valid_loader, optimizer=optimizer, 
    forward_fn=forward_image_classification, metrics_fn=acc_f1_metrics_fn, valid_criterion='ACC', 
    n_epochs=5, evaluate_every=1, early_stop=3, step_size=1, gamma=0.9, device='cuda'
)

# Load best model
# train() returns the model as of the last epoch; the best-by-validation
# weights are the ones train() saved to ./best_model.th.
# NOTE(review): this is the same checkpoint path the original-model cell uses;
# it is overwritten during this run's first validation.
fact_model.load_state_dict(torch.load("./best_model.th"))

# Evaluation phase
print('=== Evaluation Phase ===')
test_loss, test_metrics, test_hyp, test_label = evaluate(fact_model, data_loader=test_loader, 
        forward_fn=forward_image_classification, metrics_fn=acc_f1_metrics_fn, is_test=True, device='cuda')
print(test_metrics)

# Drop the references to the optimizer and the trained factorized model.
del optimizer, fact_model

(Epoch 1) TRAIN LOSS:1.6294 LR:0.00100000: 100%|████████| 192/192 [00:28<00:00,  6.67it/s]

(Epoch 1) TRAIN LOSS:1.6294 ACC:0.37 F1:0.37 LR:0.00100000



VALID LOSS:1.1791 ACC:0.60 F1:0.59: 100%|███████████████████| 4/4 [00:00<00:00, 11.55it/s]
(Epoch 2) TRAIN LOSS:0.8822 LR:0.00090000: 100%|████████| 192/192 [00:28<00:00,  6.70it/s]

(Epoch 2) TRAIN LOSS:0.8822 ACC:0.70 F1:0.70 LR:0.00090000



VALID LOSS:0.7820 ACC:0.74 F1:0.74: 100%|███████████████████| 4/4 [00:00<00:00, 11.42it/s]
(Epoch 3) TRAIN LOSS:0.6562 LR:0.00081000: 100%|████████| 192/192 [00:28<00:00,  6.74it/s]

(Epoch 3) TRAIN LOSS:0.6562 ACC:0.78 F1:0.78 LR:0.00081000



VALID LOSS:0.6813 ACC:0.78 F1:0.77: 100%|███████████████████| 4/4 [00:00<00:00, 12.09it/s]
(Epoch 4) TRAIN LOSS:0.5527 LR:0.00072900: 100%|████████| 192/192 [00:28<00:00,  6.80it/s]

(Epoch 4) TRAIN LOSS:0.5527 ACC:0.82 F1:0.82 LR:0.00072900



VALID LOSS:0.5727 ACC:0.82 F1:0.81: 100%|███████████████████| 4/4 [00:00<00:00, 12.40it/s]
(Epoch 5) TRAIN LOSS:0.4804 LR:0.00065610: 100%|████████| 192/192 [00:28<00:00,  6.77it/s]

(Epoch 5) TRAIN LOSS:0.4804 ACC:0.84 F1:0.84 LR:0.00065610



VALID LOSS:0.4975 ACC:0.83 F1:0.83: 100%|███████████████████| 4/4 [00:00<00:00, 12.93it/s]


=== Evaluation Phase ===


TEST LOSS:0.4961 ACC:0.83 F1:0.83: 100%|██████████████████| 40/40 [00:02<00:00, 14.39it/s]

{'ACC': 0.8281, 'F1': 0.8260633391714105}



