In [1]:
from torchvision.datasets import SUN397
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch
import torch.nn as nn
import tqdm
import os
from IPython.display import clear_output
import numpy as np

import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import MultiStepLR
import torchvision.models as models
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt



In [2]:
#!g1.1
# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


## Обучение

In [3]:
import wandb

In [4]:
class SUN397Dataset(Dataset):
    """SUN397 scene classification with a reproducible train/validation split.

    torchvision's ``SUN397`` has no ``train``/``split`` argument (it exposes the
    whole image collection), so forwarding ``train=`` to it raises ``TypeError``.
    This wrapper builds the full torchvision dataset and keeps a subset of it:
    a permutation of all indices is drawn from a generator seeded with
    ``split_seed``; its first ``round(train_fraction * N)`` entries form the
    training subset and the remaining ones the validation subset. Two
    instances built with the same ``train_fraction``/``split_seed`` and
    opposite ``train`` flags are therefore disjoint and together cover every
    image.

    Args:
        root: directory holding (or receiving) the dataset files.
        train: select the training subset (with augmentation) when True,
            the validation subset otherwise.
        download: forwarded to torchvision's ``SUN397``.
        train_fraction: share of images assigned to training, strictly in (0, 1).
        split_seed: seed of the permutation that defines the split.
        crop_size: side of the square crop returned for every image. Images are
            first resized so their shorter side is ``max(256, crop_size)``.
            Cropping is required because SUN397 images have different aspect
            ratios and the default ``DataLoader`` collation can only stack
            equally-shaped tensors.

    Raises:
        ValueError: if ``train_fraction`` is not strictly between 0 and 1.
    """

    # ImageNet per-channel statistics (the blue-channel mean is 0.406; it was
    # previously mistyped as 0.486).
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, train=True, download=True,
                 train_fraction=0.8, split_seed=0, crop_size=224):
        from torch.utils.data import Subset

        resize_size = max(256, crop_size)
        if train:
            transform = T.Compose([
                T.Resize(resize_size),
                T.RandomCrop(crop_size),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                T.Normalize(mean=self._MEAN, std=self._STD),
            ])
        else:
            transform = T.Compose([
                T.Resize(resize_size),
                T.CenterCrop(crop_size),
                T.ToTensor(),
                T.Normalize(mean=self._MEAN, std=self._STD),
            ])

        full_dataset = SUN397(root, transform=transform, download=download)
        indices = self._split_indices(len(full_dataset), train, train_fraction, split_seed)
        self.tv_dataset = Subset(full_dataset, indices)

    @staticmethod
    def _split_indices(n, train, train_fraction, seed):
        """Return the indices (out of ``range(n)``) of the requested subset; deterministic in ``seed``."""
        if not 0.0 < train_fraction < 1.0:
            raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction!r}")
        # A private generator keeps the split independent of the global torch RNG.
        generator = torch.Generator().manual_seed(seed)
        permutation = torch.randperm(n, generator=generator).tolist()
        n_train = int(round(n * train_fraction))
        return permutation[:n_train] if train else permutation[n_train:]

    def __len__(self):
        return len(self.tv_dataset)

    def __getitem__(self, ix):
        return self.tv_dataset[ix]


In [5]:
def prune_net(net, prune_ratio=0.3, method="l1"):
    """Globally prune the weights of every ``Conv2d`` layer in ``net``, in place.

    One threshold is computed across all conv weights together
    (``prune.global_unstructured``), so individual layers may end up with
    different sparsities. Pruned layers are reparametrized by
    ``torch.nn.utils.prune``: ``weight_orig`` becomes the trainable parameter
    and ``weight = weight_orig * weight_mask``. On an already pruned net, the
    existing masks are respected and ``prune_ratio`` applies to the weights
    that are still unpruned.

    Args:
        net: model to prune (modified in place).
        prune_ratio: fraction of the (remaining) conv weights to zero out.
        method: ``"l1"`` removes the smallest-magnitude weights, ``"random"``
            removes weights uniformly at random. (Previously this argument was
            silently ignored and L1 was always used.)

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    pruning_methods = {
        "l1": prune.L1Unstructured,
        "random": prune.RandomUnstructured,
    }
    if method not in pruning_methods:
        raise ValueError(
            f"Unknown pruning method {method!r}; expected one of {sorted(pruning_methods)}"
        )

    parameters_to_prune = [
        (module, "weight") for module in net.modules() if isinstance(module, torch.nn.Conv2d)
    ]
    if not parameters_to_prune:
        # No conv layers: nothing to prune (global_unstructured fails on an empty list).
        return

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=pruning_methods[method],
        amount=prune_ratio,
    )

def copy_weights_layer(layer_unpruned, layer_pruned):
    """Copy parameters and BatchNorm running statistics into ``layer_pruned``.

    Each destination entry is copied only if it appears in
    ``layer_pruned.state_dict()``. A layer reparametrized by
    ``torch.nn.utils.prune`` stores ``weight_orig`` instead of ``weight``; in
    that case the source ``weight`` is written into ``weight_orig``, leaving the
    existing mask untouched.
    """
    # (destination state-dict key, attribute to read on the source layer)
    copy_plan = (
        ("weight", "weight"),
        ("weight_orig", "weight"),
        ("bias", "bias"),
        ("running_mean", "running_mean"),
        ("running_var", "running_var"),
    )
    with torch.no_grad():
        present_keys = layer_pruned.state_dict().keys()
        for target_key, source_attr in copy_plan:
            if target_key not in present_keys:
                continue
            getattr(layer_pruned, target_key).copy_(getattr(layer_unpruned, source_attr))

def copy_weights_net(net_unpruned, net_pruned):
    """Copy weights from ``net_unpruned`` into ``net_pruned`` module by module.

    Modules are paired positionally in ``.modules()`` order, so both nets are
    expected to share the same architecture (``zip`` stops at the shorter
    sequence). A pair is copied only when the source module's own state dict
    has a bare ``weight`` key; the per-layer work is delegated to
    ``copy_weights_layer``.
    """
    for source_module, target_module in zip(net_unpruned.modules(), net_pruned.modules()):
        if "weight" not in source_module.state_dict():
            continue
        copy_weights_layer(source_module, target_module)


In [6]:
#!g1.1
# kwargs = dict(
#     num_classes=397,
# )

# base_weights_net = models.resnet18(**kwargs).to(device)
# torch.save(base_weights_net, "base_weights_net")

In [7]:
# os.mkdir("models")

In [8]:
#!g1.1
def loop_dataloader(dataloader):
    """Yield batches from ``dataloader`` forever, starting a new pass whenever it is exhausted.

    Each pass iterates ``dataloader`` afresh, so a ``DataLoader`` built with
    ``shuffle=True`` reshuffles between passes.

    Raises:
        ValueError: if a full pass yields no batches. Without this check an
            empty loader would make the generator loop forever without
            ever yielding.
    """
    while True:
        produced_any = False
        for batch in dataloader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError("dataloader yielded no batches; cannot loop over it")


def _evaluate_accuracy(model, dataloader_val):
    """Return the top-1 accuracy of ``model`` over ``dataloader_val`` as a Python float.

    The model is put in eval mode and run under ``torch.no_grad()``, so no
    autograd graph is built for validation batches. It is switched back to
    train mode afterwards, even if evaluation raises. Batches are moved to
    the module-level ``device``. Returns NaN when the loader yields no samples.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    try:
        with torch.no_grad():
            for X_batch_val, y_batch_val in dataloader_val:
                X_batch_val = X_batch_val.to(device)
                y_batch_val = y_batch_val.to(device)
                predictions = model(X_batch_val).argmax(dim=-1)
                n_correct += (predictions == y_batch_val).sum().item()
                n_seen += y_batch_val.numel()
    finally:
        model.train()
    return n_correct / n_seen if n_seen else float("nan")


def train(model, dataloader_train, loss_inst, optimizer, max_iter=10_000,
          dataloader_val=None, val_freq=500, scheduler=None):
    """Train ``model`` for exactly ``max_iter`` optimizer steps.

    Batches are drawn from ``dataloader_train`` in repeated passes (see
    ``loop_dataloader``) and moved to the module-level ``device``. After every
    optimizer step, ``scheduler.step()`` is called if a scheduler is given,
    so scheduler milestones are counted in iterations, not epochs.

    wandb logging happens only when ``dataloader_val`` is given:
      * ``"loss"``: the training loss at every step;
      * ``"accuracy_val"``: validation accuracy every ``val_freq`` steps,
        starting with step 0 (right after the first update).
    Both use ``step=it``, with ``it`` counting from 0.

    Args:
        model: network to train (updated in place).
        dataloader_train: iterable of ``(inputs, targets)`` batches.
        loss_inst: callable ``loss_inst(logits, targets)`` returning a scalar loss.
        optimizer: optimizer over ``model``'s parameters.
        max_iter: number of optimizer steps.
        dataloader_val: optional validation loader; enables logging.
        val_freq: validation period in steps.
        scheduler: optional LR scheduler stepped once per iteration.
    """
    model.train()
    # zip pulls from range first, so no extra batch is loaded after the last step.
    batches = zip(range(max_iter), loop_dataloader(dataloader_train))
    for it, (X_batch, y_batch) in tqdm.tqdm(batches, total=max_iter):
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        logit_batch = model(X_batch)
        loss = loss_inst(logit_batch, y_batch)
        if dataloader_val is not None:
            # Log a plain float rather than the graph-attached tensor.
            wandb.log({"loss": loss.item()}, step=it)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        if dataloader_val is not None and it % val_freq == 0:
            acc = _evaluate_accuracy(model, dataloader_val)
            wandb.log({"accuracy_val": acc}, step=it)
        
def _make_sgd(net, lr):
    """Return a fresh SGD optimizer (momentum 0.9, empty momentum buffers) over ``net``'s current parameters."""
    return torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)


def experiment(MAX_ITER, PRUNE_ITER, PRUNE_RATIO, RANDOM_STATE, SCHEDULE, LR, project_name):
    """Run one pruning experiment for ResNet-18 on SUN397 and log it to wandb.

    If ``PRUNE_RATIO > 0``, the following is repeated ``PRUNE_ITER`` times:
      1. train for ``MAX_ITER`` steps;
      2. save a checkpoint;
      3. prune ``1 - (1 - PRUNE_RATIO) ** (1 / PRUNE_ITER)`` of the remaining
         conv weights;
      4. rewind the surviving weights to their initial values.
    The final remaining fraction of conv weights is therefore
    ``1 - PRUNE_RATIO``. The final (possibly pruned) network is then trained
    for ``MAX_ITER`` steps with validation logging.

    Checkpoints go to
    ``models/pr{PRUNE_RATIO}_sched{int(SCHEDULE)}_lr{LR}_iters{MAX_ITER}/``:
    ``prune_it{k}`` after each pruning round's training, and ``pruned`` after
    the final training.

    Args:
        MAX_ITER: optimizer steps per training phase.
        PRUNE_ITER: number of prune/rewind rounds (unused when ``PRUNE_RATIO`` is 0).
        PRUNE_RATIO: target overall fraction of conv weights to prune.
        RANDOM_STATE: seed for the global torch RNG, or None to skip seeding.
        SCHEDULE: if truthy, use ``MultiStepLR`` (x0.1 at steps 6000 and 8000)
            in the final training phase.
        LR: SGD learning rate.
        project_name: wandb project to log to.
    """
    run_dir = os.path.join(
        "models",
        "pr" + str(PRUNE_RATIO) + "_sched" + str(int(SCHEDULE)) + "_lr" + str(LR) + "_iters" + str(MAX_ITER),
    )
    # makedirs also creates the parent "models/" directory and tolerates an
    # existing run directory, so re-running a configuration no longer raises.
    os.makedirs(run_dir, exist_ok=True)

    args = {
        "max-iter": MAX_ITER,
        "batch-size": 64,
        "prune-iter": PRUNE_ITER,
        "prune-ratio": PRUNE_RATIO,
        "prune-method": "l1",
        "val-freq": 250,
        "random-state": RANDOM_STATE,
    }

    wandb.init(
        project=project_name,
        entity="bspanfilov",
        force=True,
        name="pr" + str(PRUNE_RATIO) + "_sched" + str(int(SCHEDULE)) + "_lr" + str(LR),
        config=args,
    )
    try:
        wandb.define_metric("accuracy_val", summary="max")

        if args["random-state"] is not None:
            torch.manual_seed(args["random-state"])

        dataset_train = SUN397Dataset("data", train=True, download=True)
        dataset_val = SUN397Dataset("data", train=False, download=True)
        dataloader_train = DataLoader(
            dataset_train, batch_size=args["batch-size"], shuffle=True
        )
        dataloader_val = DataLoader(
            dataset_val, batch_size=args["batch-size"], shuffle=True
        )

        # torch.load unpickles arbitrary objects: only load this trusted local file.
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        net = torch.load("base_weights_net", map_location=device).to(device)
        # Snapshot of the initial weights, used to rewind after each pruning round.
        net_copy = models.resnet18(num_classes=397).to(device)
        net_copy.load_state_dict(net.state_dict())

        loss_inst = nn.CrossEntropyLoss()

        if args["prune-ratio"] > 0:
            per_round_prune_ratio = 1 - (1 - args["prune-ratio"]) ** (1 / args["prune-iter"])

            for prune_it in range(args["prune-iter"]):
                # A fresh optimizer per round: momentum accumulated before the
                # rewind must not carry over into the next round's training.
                train(
                    net,
                    dataloader_train,
                    loss_inst,
                    _make_sgd(net, LR),
                    max_iter=args["max-iter"],
                )
                torch.save(net, os.path.join(run_dir, "prune_it" + str(prune_it)))

                prune_net(net, per_round_prune_ratio, method=args["prune-method"])
                copy_weights_net(net_copy, net)

        # Final training of the (possibly pruned) network, with validation logging.
        optimizer = _make_sgd(net, LR)
        # Milestones are in iterations because train() steps the scheduler every step.
        scheduler = MultiStepLR(optimizer, milestones=[6000, 8000], gamma=0.1) if SCHEDULE else None
        train(
            net,
            dataloader_train,
            loss_inst,
            optimizer,
            max_iter=args["max-iter"],
            dataloader_val=dataloader_val,
            val_freq=args["val-freq"],
            scheduler=scheduler,
        )

        torch.save(net, os.path.join(run_dir, "pruned"))
    except BaseException:
        # Close the wandb run as failed instead of leaving it open.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()



In [10]:
#!g1.1
# The settings I used for the experiments. The lists are zipped element-wise:
# each position describes one experiment run. With a prune ratio of 0 no
# pruning rounds are run, so the prune-iteration value has no effect.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
MAX_ITERS = [30000]
PRUNE_ITERS = [5]
PRUNE_RATIOS = [0]
RANDOM_STATES = [2]
SCHEDULE = [False]
LR = [0.03]
for max_iter, prune_iter, prune_ratio, random_state, use_schedule, lr in zip(
    MAX_ITERS, PRUNE_ITERS, PRUNE_RATIOS, RANDOM_STATES, SCHEDULE, LR
):
    experiment(max_iter, prune_iter, prune_ratio, random_state, use_schedule, lr, "SUN397+resnet18")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


cpu


Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/bspanfilov/.local/lib/python3.8/site-packages/wandb/__main__.py", line 1, in <module>
    from wandb.cli import cli
  File "/home/bspanfilov/.local/lib/python3.8/site-packages/wandb/cli/cli.py", line 933, in <module>
    def launch_sweep(
  File "/usr/lib/python3/dist-packages/click/core.py", line 1234, in decorator
    cmd = command(*args, **kwargs)(f)
  File "/usr/lib/python3/dist-packages/click/decorators.py", line 115, in decorator
    cmd = _make_command(f, name, attrs, cls)
  File "/usr/lib/python3/dist-packages/click/decorators.py", line 88, in _make_command
    return cls(name=name or f.__name__.lower().replace('_', '-'),
TypeError: __init__() got an unexpected keyword argument 'no_args_is_help'


ServiceStartProcessError: The wandb service process exited with 1. Ensure that `sys.executable` is a valid python interpreter. You can override it with the `_executable` setting or with the `WANDB__EXECUTABLE` environment variable.