In [1]:
from pathlib import Path

from torch_geometric.datasets import ModelNet
import torch_geometric.transforms as T

current_path = Path.cwd()
dataset_dir = current_path / "modelnet10"

pre_transform = T.Compose([
    T.SamplePoints(1024, remove_faces=True, include_normals=False),
    T.NormalizeScale(),
])

train_dataset = ModelNet(dataset_dir, name="10", train=True, transform=None, pre_transform=pre_transform, pre_filter=None)
test_dataset = ModelNet(dataset_dir, name="10", train=False, transform=None, pre_transform=pre_transform, pre_filter=None)


Bad key "text.kerning_factor" on line 4 in
/opt/conda/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.1.3/matplotlibrc.template
or from the matplotlib source distribution


In [3]:
print("train_dataset len:", len(train_dataset))
print(train_dataset[0])

In [4]:
print(train_dataset[0].pos.shape)
print(train_dataset[0].pos)

In [5]:
from torch_geometric.data import DataLoader as DataLoader
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
batch = next(iter(dataloader))
print(batch)

In [6]:
from torch_geometric.nn import global_max_pool
import torch.nn as nn

class SymmFunction(nn.Module):
    def __init__(self):
        super(SymmFunction, self).__init__()
        self.shared_mlp = nn.Sequential(
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 512),
        )
        
    def forward(self, batch):
        x = self.shared_mlp(batch.pos)
        x = global_max_pool(x, batch.batch)
        return x

f = SymmFunction()
print(batch)
y = f(batch)
print(y.shape)

Batch(batch=[32768], pos=[32768, 3], y=[32])
torch.Size([32, 512])


In [7]:
class InputTNet(nn.Module):
    def __init__(self):
        super(InputTNet, self).__init__()
        self.input_mlp = nn.Sequential(
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        self.output_mlp = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 9)
        )
        
    def forward(self, x, batch):
        x = self.input_mlp(x)
        x = global_max_pool(x, batch)
        x = self.output_mlp(x)
        x = x.view(-1, 3, 3)
        id_matrix = torch.eye(3).to(x.device).view(1, 3, 3).repeat(x.shape[0], 1, 1)
        x = id_matrix + x
        return x

In [8]:
class FeatureTNet(nn.Module):
    def __init__(self):
        super(FeatureTNet, self).__init__()
        self.input_mlp = nn.Sequential(
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        self.output_mlp = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 64*64)
        )
        
    def forward(self, x, batch):
        x = self.input_mlp(x)
        x = global_max_pool(x, batch)
        x = self.output_mlp(x)
        x = x.view(-1, 64, 64)
        id_matrix = torch.eye(64).to(x.device).view(1, 64, 64).repeat(x.shape[0], 1, 1)
        x = id_matrix + x
        return x

In [9]:
class PointNetClassification(nn.Module):
    def __init__(self):
        super(PointNetClassification, self).__init__()
        self.input_tnet = InputTNet()
        self.mlp1 = nn.Sequential(
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
        )
        self.feature_tnet = FeatureTNet()
        self.mlp2 = nn.Sequential(
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        self.mlp3 = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(p=0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(p=0.3),
            nn.Linear(256, 10)
        )
        
    def forward(self, batch_data):
        x = batch_data.pos
        
        input_transform = self.input_tnet(x, batch_data.batch)
        transform = input_transform[batch_data.batch, :, :]
        x = torch.bmm(transform, x.view(-1, 3, 1)).view(-1, 3)
        
        x = self.mlp1(x)
        
        feature_transform = self.feature_tnet(x, batch_data.batch)
        transform = feature_transform[batch_data.batch, :, :]
        x = torch.bmm(transform, x.view(-1, 64, 1)).view(-1, 64)

        x = self.mlp2(x)        
        x = global_max_pool(x, batch_data.batch)
        x = self.mlp3(x)
        
        return x, input_transform, feature_transform

In [10]:
import torch
from torch.utils.tensorboard import SummaryWriter

num_epoch = 400
batch_size = 32

device = torch.device("cuda:0")
model = PointNetClassification()
model = model.to(device)

optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=num_epoch // 4, gamma=0.5)

log_dir = current_path / "log_modelnet10_classification"
log_dir.mkdir(exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

criteria = torch.nn.CrossEntropyLoss()

In [11]:
from tqdm import tqdm

for epoch in tqdm(range(num_epoch)):
    model = model.train()
    
    losses = []
    for batch_data in tqdm(train_dataloader, total=len(train_dataloader)):
        batch_data = batch_data.to(device)
        this_batch_size = batch_data.batch.detach().max() + 1
        
        pred_y, _, feature_transform = model(batch_data)
        true_y = batch_data.y.detach()

        class_loss = criteria(pred_y, true_y)
        accuracy = float((pred_y.argmax(dim=1) == true_y).sum()) / float(this_batch_size)

        id_matrix = torch.eye(feature_transform.shape[1]).to(feature_transform.device).view(1, 64, 64).repeat(feature_transform.shape[0], 1, 1)
        transform_norm = torch.norm(torch.bmm(feature_transform, feature_transform.transpose(1, 2)) - id_matrix, dim=(1, 2))
        reg_loss = transform_norm.mean()

        loss = class_loss + reg_loss * 0.001
        
        losses.append({
            "loss": loss.item(),
            "class_loss": class_loss.item(),
            "reg_loss": reg_loss.item(),
            "accuracy": accuracy,
            "seen": float(this_batch_size)})
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
    if (epoch % 10 == 0):
        model_path = log_dir / f"model_{epoch:06}.pth"
        torch.save(model.state_dict(), model_path)
    
    loss = 0
    class_loss = 0
    reg_loss = 0
    accuracy = 0
    seen = 0
    for d in losses:
        seen = seen + d["seen"]
        loss = loss + d["loss"] * d["seen"]
        class_loss = class_loss + d["class_loss"] * d["seen"]
        reg_loss = reg_loss + d["reg_loss"] * d["seen"]
        accuracy = accuracy + d["accuracy"] * d["seen"]
    loss = loss / seen
    class_loss = class_loss / seen
    reg_loss = reg_loss / seen
    accuracy = accuracy / seen
    writer.add_scalar("train_epoch/loss", loss, epoch)
    writer.add_scalar("train_epoch/class_loss", class_loss, epoch)
    writer.add_scalar("train_epoch/reg_loss", reg_loss, epoch)
    writer.add_scalar("train_epoch/accuracy", accuracy, epoch)

    with torch.no_grad():
        model = model.eval()

        losses = []
        for batch_data in tqdm(test_dataloader, total=len(test_dataloader)):
            batch_data = batch_data.to(device)
            this_batch_size = batch_data.batch.detach().max() + 1

            pred_y, _, feature_transform = model(batch_data)
            true_y = batch_data.y.detach()

            class_loss = criteria(pred_y, true_y)
            accuracy =float((pred_y.argmax(dim=1) == true_y).sum()) / float(this_batch_size)

            id_matrix = torch.eye(feature_transform.shape[1]).to(feature_transform.device).view(1, 64, 64).repeat(feature_transform.shape[0], 1, 1)
            transform_norm = torch.norm(torch.bmm(feature_transform, feature_transform.transpose(1, 2)) - id_matrix, dim=(1, 2))
            reg_loss = transform_norm.mean()

            loss = class_loss + reg_loss * 0.001

            losses.append({
                "loss": loss.item(),
                "class_loss": class_loss.item(),
                "reg_loss": reg_loss.item(),
                "accuracy": accuracy,
                "seen": float(this_batch_size)})
            
        loss = 0
        class_loss = 0
        reg_loss = 0
        accuracy = 0
        seen = 0
        for d in losses:
            seen = seen + d["seen"]
            loss = loss + d["loss"] * d["seen"]
            class_loss = class_loss + d["class_loss"] * d["seen"]
            reg_loss = reg_loss + d["reg_loss"] * d["seen"]
            accuracy = accuracy + d["accuracy"] * d["seen"]
        loss = loss / seen
        class_loss = class_loss / seen
        reg_loss = reg_loss / seen
        accuracy = accuracy / seen
        writer.add_scalar("test_epoch/loss", loss, epoch)
        writer.add_scalar("test_epoch/class_loss", class_loss, epoch)
        writer.add_scalar("test_epoch/reg_loss", reg_loss, epoch)
        writer.add_scalar("test_epoch/accuracy", accuracy, epoch)

  0%|          | 0/400 [00:00<?, ?it/s]
  0%|          | 0/125 [00:00<?, ?it/s][A
  1%|          | 1/125 [00:00<01:05,  1.90it/s][A
  2%|▏         | 2/125 [00:00<00:50,  2.45it/s][A
  2%|▏         | 3/125 [00:00<00:39,  3.13it/s][A
  3%|▎         | 4/125 [00:00<00:31,  3.83it/s][A
  4%|▍         | 5/125 [00:01<00:26,  4.55it/s][A
  5%|▍         | 6/125 [00:01<00:22,  5.25it/s][A
  6%|▌         | 7/125 [00:01<00:20,  5.90it/s][A
  6%|▋         | 8/125 [00:01<00:18,  6.40it/s][A
  7%|▋         | 9/125 [00:01<00:17,  6.82it/s][A
  8%|▊         | 10/125 [00:01<00:16,  7.12it/s][A
  9%|▉         | 11/125 [00:01<00:15,  7.45it/s][A
 10%|▉         | 12/125 [00:01<00:14,  7.65it/s][A
 10%|█         | 13/125 [00:02<00:14,  7.78it/s][A
 11%|█         | 14/125 [00:02<00:14,  7.83it/s][A
 12%|█▏        | 15/125 [00:02<00:13,  8.08it/s][A
 13%|█▎        | 16/125 [00:02<00:13,  8.08it/s][A
 14%|█▎        | 17/125 [00:02<00:13,  8.09it/s][A
 14%|█▍        | 18/125 [00:02<00:13,  8.0

KeyboardInterrupt: 

Batch(batch=[32768], pos=[32768, 3], y=[32])


In [1]:
from pathlib import Path

from torch_geometric.datasets import ModelNet
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader as DataLoader
from torch_geometric.nn import global_max_pool
import torch.nn as nn
import torch
from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter


Bad key "text.kerning_factor" on line 4 in
/opt/conda/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.1.3/matplotlibrc.template
or from the matplotlib source distribution


In [2]:
current_path = Path.cwd()
dataset_dir = current_path / "modelnet10"
log_dir = current_path / "log_modelnet10_classification"

log_dir.mkdir(exist_ok=True)

batch_size = 64

In [3]:
pre_transform = T.Compose([
    T.SamplePoints(1024, remove_faces=True, include_normals=False),
    T.NormalizeScale(),
])

train_dataset = ModelNet(dataset_dir, name="10", train=True, transform=None, pre_transform=pre_transform, pre_filter=None)
test_dataset = ModelNet(dataset_dir, name="10", train=False, transform=None, pre_transform=pre_transform, pre_filter=None)

dataset = train_dataset
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [3]:
class SymmFunction(nn.Module):
    def __init__(self):
        super(SymmFunction, self).__init__()
        self.input_mlp = nn.Sequential(
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024),
        )
        
    def forward(self, x, batch):
        x = self.input_mlp(x)
        x = global_max_pool(x, batch)
        return x



In [4]:
class InputTNet(nn.Module):
    def __init__(self):
        super(InputTNet, self).__init__()
        self.input_mlp = nn.Sequential(
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        self.output_mlp = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 9)
        )
        
    def forward(self, x, batch):
        x = self.input_mlp(x)
        x = global_max_pool(x, batch)
        x = self.output_mlp(x)
        x = x.view(-1, 3, 3)
        id_matrix = torch.eye(3).to(x.device).view(1, 3, 3).repeat(x.shape[0], 1, 1)
        x = id_matrix + x
        return x

In [5]:
class FeatureTNet(nn.Module):
    def __init__(self):
        super(FeatureTNet, self).__init__()
        self.input_mlp = nn.Sequential(
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        self.output_mlp = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 64*64)
        )
        
    def forward(self, x, batch):
        x = self.input_mlp(x)
        x = global_max_pool(x, batch)
        x = self.output_mlp(x)
        x = x.view(-1, 64, 64)
        id_matrix = torch.eye(64).to(x.device).view(1, 64, 64).repeat(x.shape[0], 1, 1)
        x = id_matrix + x
        return x

In [6]:
class PointNetClassification(nn.Module):
    def __init__(self):
        super(PointNetClassification, self).__init__()
        self.input_tnet = InputTNet()
        self.mlp1 = nn.Sequential(
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
        )
        self.feature_tnet = FeatureTNet()
        self.mlp2 = nn.Sequential(
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        self.mlp3 = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(p=0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(p=0.3),
            nn.Linear(256, 10)
        )
        
    def forward(self, batch_data):
        x = batch_data.pos
        
        input_transform = self.input_tnet(x, batch_data.batch)
        transform = input_transform[batch_data.batch, :, :]
        x = torch.bmm(transform, x.view(-1, 3, 1)).view(-1, 3)
        
        x = self.mlp1(x)
        
        feature_transform = self.feature_tnet(x, batch_data.batch)
        transform = feature_transform[batch_data.batch, :, :]
        x = torch.bmm(transform, x.view(-1, 64, 1)).view(-1, 64)

        x = self.mlp2(x)        
        x = global_max_pool(x, batch_data.batch)
        x = self.mlp3(x)
        
        return x, input_transform, feature_transform

In [7]:
num_epoch = 400

device = torch.device("cuda:0")
model = PointNetClassification()
model = model.to(device)

optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=num_epoch // 4, gamma=0.5)

writer = SummaryWriter(log_dir=log_dir)
#writer.close()

pre_transform = T.Compose([
    T.SamplePoints(1024, remove_faces=True, include_normals=False),
    T.NormalizeScale(),
])

train_dataset = ModelNet(dataset_dir, name="10", train=True, transform=None, pre_transform=pre_transform, pre_filter=None)
test_dataset = ModelNet(dataset_dir, name="10", train=False, transform=None, pre_transform=pre_transform, pre_filter=None)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

criteria = torch.nn.CrossEntropyLoss()

In [8]:
for epoch in range(num_epoch):
#for epoch in tqdm(range(num_epoch)):
    model = model.train()
    
    losses = []
    for batch_data in tqdm(train_dataloader, total=len(train_dataloader)):
        batch_data = batch_data.to(device)
        this_batch_size = batch_data.batch.detach().max() + 1
        
        pred_y, _, feature_transform = model(batch_data)
        true_y = batch_data.y.detach()

        class_loss = criteria(pred_y, true_y)
        accuracy = float((pred_y.argmax(dim=1) == true_y).sum()) / float(this_batch_size)

        id_matrix = torch.eye(feature_transform.shape[1]).to(feature_transform.device).view(1, 64, 64).repeat(feature_transform.shape[0], 1, 1)
        transform_norm = torch.norm(torch.bmm(feature_transform, feature_transform.transpose(1, 2)) - id_matrix, dim=(1, 2))
        reg_loss = transform_norm.mean()

        loss = class_loss + reg_loss * 0.001
        
        losses.append({
            "loss": loss.item(),
            "class_loss": class_loss.item(),
            "reg_loss": reg_loss.item(),
            "accuracy": accuracy,
            "seen": float(this_batch_size)})
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
    if (epoch % 10 == 0):
        model_path = log_dir / f"model_{epoch:06}.pth"
        torch.save(model.state_dict(), model_path)
    
    loss = 0
    class_loss = 0
    reg_loss = 0
    accuracy = 0
    seen = 0
    for d in losses:
        seen = seen + d["seen"]
        loss = loss + d["loss"] * d["seen"]
        class_loss = class_loss + d["class_loss"] * d["seen"]
        reg_loss = reg_loss + d["reg_loss"] * d["seen"]
        accuracy = accuracy + d["accuracy"] * d["seen"]
    loss = loss / seen
    class_loss = class_loss / seen
    reg_loss = reg_loss / seen
    accuracy = accuracy / seen
    writer.add_scalar("train_epoch/loss", loss, epoch)
    writer.add_scalar("train_epoch/class_loss", class_loss, epoch)
    writer.add_scalar("train_epoch/reg_loss", reg_loss, epoch)
    writer.add_scalar("train_epoch/accuracy", accuracy, epoch)

    with torch.no_grad():
        model = model.eval()

        losses = []
        for batch_data in tqdm(test_dataloader, total=len(test_dataloader)):
            batch_data = batch_data.to(device)
            this_batch_size = batch_data.batch.detach().max() + 1

            pred_y, _, feature_transform = model(batch_data)
            true_y = batch_data.y.detach()

            class_loss = criteria(pred_y, true_y)
            accuracy =float((pred_y.argmax(dim=1) == true_y).sum()) / float(this_batch_size)

            id_matrix = torch.eye(feature_transform.shape[1]).to(feature_transform.device).view(1, 64, 64).repeat(feature_transform.shape[0], 1, 1)
            transform_norm = torch.norm(torch.bmm(feature_transform, feature_transform.transpose(1, 2)) - id_matrix, dim=(1, 2))
            reg_loss = transform_norm.mean()

            loss = class_loss + reg_loss * 0.001 * 0.001

            losses.append({
                "loss": loss.item(),
                "class_loss": class_loss.item(),
                "reg_loss": reg_loss.item(),
                "accuracy": accuracy,
                "seen": float(this_batch_size)})
            
        loss = 0
        class_loss = 0
        reg_loss = 0
        accuracy = 0
        seen = 0
        for d in losses:
            seen = seen + d["seen"]
            loss = loss + d["loss"] * d["seen"]
            class_loss = class_loss + d["class_loss"] * d["seen"]
            reg_loss = reg_loss + d["reg_loss"] * d["seen"]
            accuracy = accuracy + d["accuracy"] * d["seen"]
        loss = loss / seen
        class_loss = class_loss / seen
        reg_loss = reg_loss / seen
        accuracy = accuracy / seen
        writer.add_scalar("test_epoch/loss", loss, epoch)
        writer.add_scalar("test_epoch/class_loss", class_loss, epoch)
        writer.add_scalar("test_epoch/reg_loss", reg_loss, epoch)
        writer.add_scalar("test_epoch/accuracy", accuracy, epoch)


100%|██████████| 63/63 [02:03<00:00,  1.97s/it]
100%|██████████| 15/15 [00:05<00:00,  2.61it/s]
100%|██████████| 63/63 [01:59<00:00,  1.90s/it]
100%|██████████| 15/15 [00:05<00:00,  2.59it/s]
100%|██████████| 63/63 [01:59<00:00,  1.90s/it]
100%|██████████| 15/15 [00:05<00:00,  2.58it/s]
100%|██████████| 63/63 [01:59<00:00,  1.90s/it]
100%|██████████| 15/15 [00:05<00:00,  2.57it/s]
100%|██████████| 63/63 [01:59<00:00,  1.90s/it]
100%|██████████| 15/15 [00:05<00:00,  2.58it/s]
100%|██████████| 63/63 [01:59<00:00,  1.90s/it]
100%|██████████| 15/15 [00:05<00:00,  2.59it/s]
100%|██████████| 63/63 [01:59<00:00,  1.90s/it]
100%|██████████| 15/15 [00:05<00:00,  2.56it/s]
100%|██████████| 63/63 [01:59<00:00,  1.90s/it]
100%|██████████| 15/15 [00:05<00:00,  2.57it/s]
100%|██████████| 63/63 [01:59<00:00,  1.90s/it]
100%|██████████| 15/15 [00:05<00:00,  2.56it/s]
100%|██████████| 63/63 [01:59<00:00,  1.90s/it]
100%|██████████| 15/15 [00:05<00:00,  2.57it/s]
100%|██████████| 63/63 [01:59<00:00,  1.

In [9]:
(pred_y.argmax(dim=1) == true_y).sum().item() /2

5.5

In [10]:
writer.close()

In [11]:
losses

[{'loss': 0.4300372302532196,
  'class_loss': 0.4299507737159729,
  'reg_loss': 86.45646667480469,
  'accuracy': 0.859375,
  'seen': 64.0},
 {'loss': 0.1293186992406845,
  'class_loss': 0.12924015522003174,
  'reg_loss': 78.54734802246094,
  'accuracy': 0.984375,
  'seen': 64.0},
 {'loss': 0.20876644551753998,
  'class_loss': 0.20869794487953186,
  'reg_loss': 68.50424194335938,
  'accuracy': 0.9375,
  'seen': 64.0},
 {'loss': 0.2231782227754593,
  'class_loss': 0.2231127917766571,
  'reg_loss': 65.43126678466797,
  'accuracy': 0.953125,
  'seen': 64.0},
 {'loss': 0.6764231324195862,
  'class_loss': 0.6763547658920288,
  'reg_loss': 68.35871887207031,
  'accuracy': 0.796875,
  'seen': 64.0},
 {'loss': 0.8635827302932739,
  'class_loss': 0.8635047078132629,
  'reg_loss': 77.99657440185547,
  'accuracy': 0.734375,
  'seen': 64.0},
 {'loss': 0.28232598304748535,
  'class_loss': 0.28223732113838196,
  'reg_loss': 88.67436981201172,
  'accuracy': 0.90625,
  'seen': 64.0},
 {'loss': 0.233699