In [1]:
import torch
from torch.optim import Adam
import os
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from model import Resnet18
from tcav import TCAV

# Select which GPU the process may see, then fall back to the CPU when CUDA
# is not available.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")

# Convert images to tensors and normalise each channel with fixed mean/std values.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

# One class per sub-directory of ../data/amazon; 80% of the images are used
# for training and the remainder is held out for evaluation.
image_dataset = datasets.ImageFolder('../data/amazon', data_transforms)
num_images = len(image_dataset)
train_size = int(num_images * 0.8)
train_data, test_data = torch.utils.data.random_split(
    image_dataset,
    [train_size, num_images - train_size],
)
trainloader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=8)
testloader = DataLoader(test_data, batch_size=256, shuffle=False, num_workers=4)

# Network with 31 outputs, cross-entropy loss, and Adam with its default settings.
model = Resnet18(output_num=31).to(device)
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters())

def train(num_epochs=200, save_path='resnet18_office.pth'):
    """Train the global ``model`` and save the weights of its best epoch.

    Every epoch performs one optimisation pass over ``trainloader`` and then
    measures top-1 accuracy on ``testloader``. A copy of the weights from the
    epoch with the highest accuracy is written to ``save_path`` with
    ``torch.save`` after the last epoch.

    Uses the module-level ``model``, ``device``, ``criterion``, ``optimizer``,
    ``trainloader`` and ``testloader``.

    Args:
        num_epochs: Number of epochs to run. Defaults to 200.
        save_path: File the best state dict is saved to. Defaults to
            ``'resnet18_office.pth'``.
    """
    import copy

    # state_dict() returns tensors that share storage with the live
    # parameters, so a plain reference would silently follow every later
    # optimizer step. Deep-copy to freeze an actual snapshot.
    best_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(1, num_epochs + 1):

        # train phase
        model.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # test phase
        total = 0
        score = 0
        model.eval()
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                predicted = outputs.max(dim=1)[1]
                total += labels.size(0)
                score += predicted.eq(labels).sum().item()
        # Avoid ZeroDivisionError when the evaluation loader yields nothing.
        acc = score / total if total else 0.0
        print("epoch: {}\tacc: {}".format(epoch, acc))
        if acc > best_acc:
            best_acc = acc
            best_weights = copy.deepcopy(model.state_dict())

    # save model parameters
    torch.save(best_weights, save_path)


def validate(weights_path='resnet18_office.pth'):
    """Load saved weights into the global ``model`` and switch it to eval mode.

    Uses the module-level ``model`` and ``device``.

    Args:
        weights_path: State-dict file to load. Defaults to
            ``'resnet18_office.pth'``, the file ``train`` writes by default.
    """
    # map_location moves the stored tensors onto the current device, so a
    # checkpoint written on a GPU can still be loaded on a CPU-only machine.
    state = torch.load(weights_path, map_location=device)
    model.load_state_dict(state)
    model.eval()
    # TODO: Create DataLoaders for Broden concepts and train TCAV


In [2]:
# Run the training loop defined above. The recorded run below stopped with a
# CUDA out-of-memory error.
train()

RuntimeError: CUDA out of memory. Tried to allocate 92.00 MiB (GPU 0; 11.75 GiB total capacity; 3.70 GiB already allocated; 44.75 MiB free; 26.50 MiB cached)

In [1]:
import torch
from torch.optim import Adam
import os
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from model import Resnet18
from tcav import TCAV

In [2]:
# Rebuild the network on the CPU and restore the trained weights.
device = torch.device('cpu')
model = Resnet18(output_num=31)
model.to(device)
# map_location remaps the stored tensors to the CPU, so the checkpoint loads
# even if it was saved from a GPU and no GPU is available here.
model.load_state_dict(torch.load('../resnet18_office.pth', map_location=device))

<All keys matched successfully>

In [8]:
# Display the raw checkpoint: an OrderedDict keyed by parameter name
# (the recorded output below is truncated).
torch.load('../resnet18_office.pth')

OrderedDict([('conv1.weight',
              tensor([[[[-1.2392e-02, -7.3720e-03, -2.6712e-03,  ...,  5.5574e-02,
                          1.6588e-02, -1.3773e-02],
                        [ 9.2389e-03,  8.3260e-03, -1.1101e-01,  ..., -2.7245e-01,
                         -1.2991e-01,  2.4014e-03],
                        [-7.9552e-03,  5.8543e-02,  2.9493e-01,  ...,  5.1915e-01,
                          2.5605e-01,  6.2490e-02],
                        ...,
                        [-2.8052e-02,  1.5821e-02,  7.2720e-02,  ..., -3.3290e-01,
                         -4.2070e-01, -2.5884e-01],
                        [ 3.0059e-02,  4.0795e-02,  6.2639e-02,  ...,  4.1402e-01,
                          3.9335e-01,  1.6510e-01],
                        [-1.4163e-02, -3.8011e-03, -2.4305e-02,  ..., -1.5031e-01,
                         -8.1721e-02, -5.9926e-03]],
              
                       [[-1.2606e-02, -2.7447e-02, -3.5111e-02,  ...,  3.1913e-02,
                          7.4464

In [9]:
# Display a second checkpoint file for visual inspection.
# NOTE(review): presumably a CPU-mapped copy of the same weights, going by the file name — confirm.
torch.load('cpu_resnet18_office.pth')

OrderedDict([('conv1.weight',
              tensor([[[[-1.2392e-02, -7.3720e-03, -2.6712e-03,  ...,  5.5574e-02,
                          1.6588e-02, -1.3773e-02],
                        [ 9.2389e-03,  8.3260e-03, -1.1101e-01,  ..., -2.7245e-01,
                         -1.2991e-01,  2.4014e-03],
                        [-7.9552e-03,  5.8543e-02,  2.9493e-01,  ...,  5.1915e-01,
                          2.5605e-01,  6.2490e-02],
                        ...,
                        [-2.8052e-02,  1.5821e-02,  7.2720e-02,  ..., -3.3290e-01,
                         -4.2070e-01, -2.5884e-01],
                        [ 3.0059e-02,  4.0795e-02,  6.2639e-02,  ...,  4.1402e-01,
                          3.9335e-01,  1.6510e-01],
                        [-1.4163e-02, -3.8011e-03, -2.4305e-02,  ..., -1.5031e-01,
                         -8.1721e-02, -5.9926e-03]],
              
                       [[-1.2606e-02, -2.7447e-02, -3.5111e-02,  ...,  3.1913e-02,
                          7.4464