In [1]:
import qtorch

In [2]:

from torchvision import datasets, transforms
import torch

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

MnistDataset = datasets.MNIST("./data", download=True, train=False, transform=transform)

MnistDatasetTrain =torch.utils.data.Subset(MnistDataset, range(7500)) 
MnistDatasetTest =torch.utils.data.Subset(MnistDataset, range(7500, 10000)) 

In [3]:
from copy import deepcopy
import torch
from torch import nn
import qtorch
from qtorch.quant import Quantizer

fnumber = qtorch.FloatingPoint(5, 10)
bnumber = qtorch.FloatingPoint(5, 10)

def get_network():
    network = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(784, 256),
        torch.nn.Dropout(0.1),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 10),
        torch.nn.Softmax(dim=1)
    )
    return network

low_precision_network = torch.nn.Sequential(
    torch.nn.Flatten(),
    Quantizer(forward_number=fnumber),
    torch.nn.Linear(784, 256),
    Quantizer(backward_number=bnumber),
    torch.nn.ReLU(),
    Quantizer(forward_number=fnumber),
    torch.nn.Linear(256, 10),
    Quantizer(backward_number=bnumber),
    torch.nn.Softmax(dim=1)
)

master_weight = deepcopy(low_precision_network)


class MasterWeightOptimizerWrapper():
    def __init__(
            self,
            master_weight,
            model_weight,
            optimizer,
            weight_quant=None,
            grad_scaling=1.0,
    ):
        self.master_weight = master_weight
        self.model_weight = model_weight
        self.optimizer = optimizer
        self.weight_quant = weight_quant
        self.grad_scaling = grad_scaling

    # --- for mix precision training ---
    def model_grads_to_master_grads(self):
        for model, master in zip(self.model_weight.parameters(), self.master_weight.parameters()):
            if master.grad is None:
                master.grad = master.data.new(*master.data.size())
            master.grad.data.copy_(self.grad_quant(model.grad.data))

    def master_grad_apply(self, fn):
        for master in (self.master_weight.parameters()):
            if master.grad is None:
                master.grad = master.data.new(*master.data.size())
            master.grad.data = fn(master.grad.data)

    def master_params_to_model_params(self):
        for model, master in zip(self.model_weight.parameters(), self.master_weight.parameters()):
            model.data.copy_(self.weight_quant(master.data))

    def _apply_model_weights(self, model, quant_func):
        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                m.weight.data = quant_func(m.weight.data)
                if m.bias is not None:
                    m.bias.data = quant_func(m.bias.data)

    def train_on_loss(self, loss):
        loss = loss * self.loss_scale
        opt = self.optimizer
        self.model_weight.zero_grad()
        self.model_weight.backward(loss)
        self.model_grads_to_master_grads()
        self.master_grad_apply(lambda x: x / self.loss_scale)
        nn.utils.clip_grad_norm_(self.master_weight.parameters(), self.grad_clip)
        opt.step()
        self.master_params_to_model_params()
        self._apply_model_weights(self.model_weight, self.weight_quant)
        return loss.item()

In [6]:
import torch
import torch.nn as nn
import wandb
device = "cuda"

def test(network, dataset):
    network.eval()
    correct = 0
    total_loss = 0
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for data, target in dataset:
            data = data.to(device)
            target = target.to(device)
            output = network(data)
            total_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / len(dataset.dataset)  # ensure dataset.dataset is the correct reference
    avg_loss = total_loss / len(dataset)

    return {"acc": accuracy, "test_loss": avg_loss}

class MovingAvg():
    def __init__(self, beta=0.9):
        self.beta = beta
        self.average = None

    def update(self, value):
        if self.average is None:
            self.average = value
        else:
            self.average = self.beta * self.average + (1 - self.beta) * value

    def get(self):
        return self.average

class MovingAvgStat():
    def __init__(self, beta):
        self.beta = beta
        self.stats = {}

    def add_value(self, stats):
        for key in stats:
            if key not in self.stats:
                self.stats[key] = MovingAvg(self.beta)
            self.stats[key].update(stats[key])

    def get(self):
        return {f"{key}_mov_avg": self.stats[key].get() for key in self.stats}

def report_stats(network):
    stats = {}
    i = 0
    for _, m in enumerate(network.modules()):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            stats[f"{i}_w_norm"] = m.weight.data.norm().item()
            stats[f"{i}_w_mean"] = m.weight.data.mean().item()
            stats[f"{i}_w_std"] = m.weight.data.std().item()
            stats[f"{i}_w_max"] = m.weight.data.max().item()
            stats[f"{i}_w_min"] = m.weight.data.min().item()
            stats[f"{i}_g_norm"] = torch.sqrt((m.weight.grad.data ** 2).sum()).item() if m.weight.grad is not None else 0
            stats[f"{i}_g_mean"] = m.weight.grad.data.mean().item() if m.weight.grad is not None else 0
            stats[f"{i}_g_std"] = m.weight.grad.data.std().item() if m.weight.grad is not None else 0
            i += 1
    return stats

def train(network, dataset, test_dataset, steps, lr=0.01):
    wandb.init(project="convergence_srsgd")
    wandb.watch(network, log='all', log_freq=10)

    network.train()
    optimizer = torch.optim.SGD(network.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    iteration = 0
    loss_sum = 0
    stats_moving_average = MovingAvgStat(0.9)

    while True:
        for data, target in dataset:
            data = data.to(device)
            target = target.to(device)
            if iteration >= steps:
                return network
            iteration += 1
            optimizer.zero_grad()
            output = network(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()

            if iteration % 10 == 0:
                avg_loss = loss_sum / 10
                wandb.log({"iteration_loss": avg_loss})
                loss_sum = 0
                stats = report_stats(network)
                # stats_moving_average.add_value(stats)
                wandb.log({"0_g_norm": stats["0_g_norm"]})
                wandb.log({"1_g_norm": stats["1_g_norm"]})
                # wandb.log({"0_g_norm_mov_avg": stats_moving_average.get()["0_g_norm_mov_avg"]})
                # wandb.log({"1_g_norm_mov_avg": stats_moving_average.get()["1_g_norm_mov_avg"]})

            if iteration % 100 == 0:
                test_metrics = test(network, test_dataset)
                wandb.log(test_metrics)


In [13]:
MnistDatasetTrain.dataset.train_data



tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        ...,

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0,

In [8]:
network = get_network()
train_loader = torch.utils.data.DataLoader(MnistDatasetTrain, batch_size=7500, shuffle=True)
test_loader = torch.utils.data.DataLoader(MnistDatasetTest, batch_size=512, shuffle=True)
network = network.to(device)
train(network,train_loader, test_loader, steps=10000)

0,1
0_g_norm,▃▄▇██▆▆▆▄▄▃▄▄▄▃▃▃▃▃▂▂▃▅▂▂▁▄▃▁▂▂▃▄▃▂▂▂▂▂▁
1_g_norm,▂▃▆██▇█▇▅▄▄▆▆▆▃▅▃▃▃▃▃▄▇▄▂▁▆▃▂▂▃▃▄▄▃▃▃▄▃▁
acc,▁▂▃▄▅▆▆▇▇▇▇▇█████████████████████████
iteration_loss,██▇▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_loss,█▇▆▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
0_g_norm,0.10888
1_g_norm,0.06026
acc,0.9176
iteration_loss,1.5692
test_loss,1.56501


KeyboardInterrupt: 