In [57]:
import torch
from torch import nn
from torch import optim
import numpy as np
#from tqdm.notebook import tqdm
from tqdm import tqdm

In [58]:
def get_subset(dataset, indices, region):
    subset = []
    for i in indices:
        subset.append((dataset[i][region[0]], dataset[i][region[1]]))
    return subset

def slash_train_val(dataset, percentage):
    total = len(dataset)
    val_data = np.random.choice(list(range(total)),int(total*percentage), replace=False).tolist()
    train_data = list(set(range(total)) - set(val_data))
    return val_data, train_data

In [120]:
# dataSetFile: 2D numpy array, [# of moives, # of attributes]
# dataSetFile[0] --> 211 elements, [0:50] plot embeddings, [50:209] attributes, [209] IMDB rating, [210] Douban rating; Missing data = -1
dataSetFile = np.load("npAttrEmbOvwDoubanR.npy")

isMethod1 = True
if isMethod1:
    dataset = []
    for i in range(len(dataSetFile)):
        dataset.append((np.concatenate([dataSetFile[i,0:50]/1000, dataSetFile[i,50:209],dataSetFile[i,209:210]]), dataSetFile[i,210])) # data normalization
    # IMDBintersectDouban: 1D array, movie indices in dataSetFile that IMDB $\cap$ Douban
    IMDBintersectDouban = np.load("IMDBIntersectDouban.npy")
    dataset_intersect = get_subset(dataset, IMDBintersectDouban, [0,1])

    # IMDBintersectDouban: 1D array, movie indices in dataSetFile that IMDB $-$ Douban
    IMDBDifferenceDouban = np.load("IMDBDifferenceDouban.npy")
    dataset_difference = get_subset(dataset, IMDBDifferenceDouban, [0,1])
    
    dummy_rating = 10
    for i in range(len(dataset_difference)):
        a = 1*dataset_difference[i][0][209:210]
                
        y = list(dataset_difference[i])
        y[1] = a
        dataset_difference[i] = tuple(y)
        dataset_difference[i][0][209:210] = dummy_rating
        #print(dataset_difference[i])



    percentage = 0.15 # percentage of the evaluation set

    # slash the three datasets into training and evaluation set
    train_total, val_total = slash_train_val(dataset, percentage)
    train_intersect, val_intersect = slash_train_val(dataset_intersect, percentage)
    train_difference, val_difference = slash_train_val(dataset_difference, percentage)
else:
    dataset = []
    for i in range(len(dataSetFile)):
        dataset.append((np.concatenate([dataSetFile[i,0:50]/1000, dataSetFile[i,50:209]]), dataSetFile[i,209], dataSetFile[i,210])) # data normalization

    # IMDBintersectDouban: 1D array, movie indices in dataSetFile that IMDB $\cap$ Douban
    IMDBintersectDouban = np.load("IMDBIntersectDouban.npy")
    dataset_intersect = get_subset(dataset, IMDBintersectDouban, [0,2])

    # IMDBintersectDouban: 1D array, movie indices in dataSetFile that IMDB $-$ Douban
    IMDBDifferenceDouban = np.load("IMDBDifferenceDouban.npy")
    dataset_difference = get_subset(dataset, IMDBDifferenceDouban, [0,1])

    percentage = 0.15 # percentage of the evaluation set

    # slash the three datasets into training and evaluation set
    train_total, val_total = slash_train_val(dataset, percentage)
    train_intersect, val_intersect = slash_train_val(dataset_intersect, percentage)
    train_difference, val_difference = slash_train_val(dataset_difference, percentage)

# sanity check
print(len(train_total))
print(len(val_total))

1276
7234


In [60]:
# The model we use
class NewNet(nn.Module):
    def __init__(self):
        super(NewNet, self).__init__()
        exp = 128
        emb1 = 10
        emb2 = 10
        self.fc11 = nn.Linear(50, emb1)
        self.fc12 = nn.Linear(159, emb2)
        self.fc2 = nn.Linear(emb1+emb2, exp)
        self.fc3 = nn.Linear(exp, exp)
        self.fc4 = nn.Linear(exp,10)
        self.relu = nn.LeakyReLU()
        self.tail = nn.Softmax(dim=1)
        self.drop = torch.nn.Dropout(0.0)
    def forward(self, x):
        x1 = self.drop(self.relu(self.fc11(x[:,:50])))
        x2 = self.drop(self.relu(self.fc12(x[:,50:])))
        x = torch.cat([x1,x2],dim=1)
        x = self.drop(self.relu(self.fc2(x)))
        x = self.drop(self.relu(self.fc3(x)))
        return (self.tail(self.fc4(x)) * (torch.Tensor(list(range(1,11))).to(x.device))).sum(dim=1)

In [61]:
# The model we use
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        exp = 128
        emb1 = 10
        emb2 = 20 
        self.fc11 = torch.nn.Linear(50, emb1)
        self.fc12 = torch.nn.Linear(159, emb2)
        self.fc13 = torch.nn.Linear(1, 1)
        
        self.fc2 = nn.Linear(emb1+emb2+1, exp)
        self.fc3 = nn.Linear(exp, exp)
        self.fc4 = nn.Linear(exp,10)
        self.drop = torch.nn.Dropout(0.0)
        self.tail = torch.nn.Softmax(dim=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = self.drop(self.relu(self.fc11(x[:,:50])))
        x2 = self.drop(self.relu(self.fc12(x[:,50:209])))
        x3 = self.drop(self.fc13(x[:,209:210]))

        x = torch.cat([x1,x2,x3],dim=1)
        x = self.drop(self.relu(self.fc2(x)))
        x = self.drop(self.relu(self.fc3(x)))
        return (self.tail(self.fc4(x)) * (torch.Tensor(list(range(1,11))).to(x.device))).sum(dim=1)

In [62]:
# Training and testing functions for pytorch
def train(epochs):
    loss_min = 10000
    iter_loader = tqdm(range(epochs))
    for _ in iter_loader:
        running_loss = 0.
        i = 0
        for data, labels in train_loader:
            # data = expand_data(data)
            data, labels = data.float().to(device), labels.to(device)
            # print(data[0], labels[0])
            # break
            optimizer.zero_grad()
            output = model(data)
            # print(output.size())
            loss = criteria(output, labels.view(output.size()).float())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            i += 1
        # print(running_loss/i)
        scheduler.step()
        val_loss = val()
        iter_loader.set_description(f"{running_loss/i:.4f}, {val_loss:.4f}, {loss_min:.4f}")
        if val_loss < loss_min:
            loss_min = val_loss

def val():
    running_loss = 0.
    i = 0
    for data, labels in val_loader:
        # data = expand_data(data)
        data, labels = data.float().to(device), labels.to(device)
        output = model(data)
        loss = criteria(output, labels.view(output.size()).float())
        running_loss += loss.item()
        i+=1
    return running_loss/i

In [68]:
## Baseline

In [63]:
# training and validation loader definition
train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(dataset_intersect, train_intersect), batch_size=1024, shuffle=True)
val_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(dataset_intersect, val_intersect), batch_size=1024, shuffle=False)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = Net()
model = model.to(device)
optimizer = optim.SGD(model.parameters(), 0.1, momentum=0.9, weight_decay=0.)
# optimizer = optim.Adam(model.parameters(), 0.1)
# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [40, 60, 80], 0.8)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [300, 600, 900], 0.1)
criteria = nn.MSELoss()
train(1000)



  0%|          | 0/1000 [00:00<?, ?it/s][A[A

3.9583, 0.7869, 10000.0000:   0%|          | 0/1000 [00:00<?, ?it/s][A[A

0.7999, 1.4742, 0.7869:   0%|          | 0/1000 [00:00<?, ?it/s]    [A[A

1.8107, 0.7309, 0.7869:   0%|          | 0/1000 [00:00<?, ?it/s][A[A

1.8107, 0.7309, 0.7869:   0%|          | 3/1000 [00:00<00:33, 29.95it/s][A[A

0.8201, 1.0580, 0.7309:   0%|          | 3/1000 [00:00<00:33, 29.95it/s][A[A

1.0121, 1.0275, 0.7309:   0%|          | 3/1000 [00:00<00:33, 29.95it/s][A[A

1.0035, 0.7953, 0.7309:   0%|          | 3/1000 [00:00<00:33, 29.95it/s][A[A

1.0035, 0.7953, 0.7309:   1%|          | 6/1000 [00:00<00:33, 29.41it/s][A[A

0.8845, 0.8119, 0.7309:   1%|          | 6/1000 [00:00<00:33, 29.41it/s][A[A

0.9899, 0.8248, 0.7309:   1%|          | 6/1000 [00:00<00:33, 29.41it/s][A[A

1.0115, 0.8003, 0.7309:   1%|          | 6/1000 [00:00<00:33, 29.41it/s][A[A

0.8979, 0.9278, 0.7309:   1%|          | 6/1000 [00:00<00:33, 29.41it/s][A[A

0.8979

0.9269, 0.9065, 0.7309:   7%|▋         | 72/1000 [00:02<00:30, 30.22it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 76/1000 [00:02<00:30, 30.64it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 76/1000 [00:02<00:30, 30.64it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 76/1000 [00:02<00:30, 30.64it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 76/1000 [00:02<00:30, 30.64it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 76/1000 [00:02<00:30, 30.64it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 80/1000 [00:02<00:29, 31.46it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 80/1000 [00:02<00:29, 31.46it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 80/1000 [00:02<00:29, 31.46it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 80/1000 [00:02<00:29, 31.46it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 80/1000 [00:02<00:29, 31.46it/s][A[A

0.9269, 0.9065, 0.7309:   8%|▊         | 84/1000 [00:02<00:28, 32.25it/s][A[A

0.9269, 0.9065, 0.7309:   8%

0.9269, 0.9065, 0.7309:  15%|█▌        | 152/1000 [00:04<00:27, 31.25it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▌        | 156/1000 [00:04<00:26, 31.92it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▌        | 156/1000 [00:05<00:26, 31.92it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▌        | 156/1000 [00:05<00:26, 31.92it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▌        | 156/1000 [00:05<00:26, 31.92it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▌        | 156/1000 [00:05<00:26, 31.92it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▌        | 160/1000 [00:05<00:26, 31.61it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▌        | 160/1000 [00:05<00:26, 31.61it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▌        | 160/1000 [00:05<00:26, 31.61it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▌        | 160/1000 [00:05<00:26, 31.61it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▌        | 160/1000 [00:05<00:26, 31.61it/s][A[A

0.9269, 0.9065, 0.7309:  16%|█▋        | 164/1000 [00:05<00:26, 31.09it/s][A[A

0.9269, 0.9065, 

0.9269, 0.9065, 0.7309:  23%|██▎       | 232/1000 [00:07<00:25, 30.59it/s][A[A

0.9269, 0.9065, 0.7309:  23%|██▎       | 232/1000 [00:07<00:25, 30.59it/s][A[A

0.9269, 0.9065, 0.7309:  24%|██▎       | 236/1000 [00:07<00:24, 31.16it/s][A[A

0.9269, 0.9065, 0.7309:  24%|██▎       | 236/1000 [00:07<00:24, 31.16it/s][A[A

0.9269, 0.9065, 0.7309:  24%|██▎       | 236/1000 [00:07<00:24, 31.16it/s][A[A

0.9269, 0.9065, 0.7309:  24%|██▎       | 236/1000 [00:07<00:24, 31.16it/s][A[A

0.9269, 0.9065, 0.7309:  24%|██▎       | 236/1000 [00:07<00:24, 31.16it/s][A[A

0.9269, 0.9065, 0.7309:  24%|██▍       | 240/1000 [00:07<00:23, 31.84it/s][A[A

0.9269, 0.9065, 0.7309:  24%|██▍       | 240/1000 [00:07<00:23, 31.84it/s][A[A

0.9269, 0.9065, 0.7309:  24%|██▍       | 240/1000 [00:07<00:23, 31.84it/s][A[A

0.9269, 0.9065, 0.7309:  24%|██▍       | 240/1000 [00:07<00:23, 31.84it/s][A[A

0.9269, 0.9065, 0.7309:  24%|██▍       | 240/1000 [00:07<00:23, 31.84it/s][A[A

0.9269, 0.9065, 

0.9269, 0.9065, 0.7309:  31%|███       | 312/1000 [00:10<00:23, 29.33it/s][A[A

0.9269, 0.9065, 0.7309:  31%|███       | 312/1000 [00:10<00:23, 29.33it/s][A[A

0.9269, 0.9065, 0.7309:  31%|███       | 312/1000 [00:10<00:23, 29.33it/s][A[A

0.9269, 0.9065, 0.7309:  31%|███       | 312/1000 [00:10<00:23, 29.33it/s][A[A

0.9269, 0.9065, 0.7309:  32%|███▏      | 316/1000 [00:10<00:22, 29.77it/s][A[A

0.9269, 0.9065, 0.7309:  32%|███▏      | 316/1000 [00:10<00:22, 29.77it/s][A[A

0.9269, 0.9065, 0.7309:  32%|███▏      | 316/1000 [00:10<00:22, 29.77it/s][A[A

0.9269, 0.9065, 0.7309:  32%|███▏      | 316/1000 [00:10<00:22, 29.77it/s][A[A

0.9269, 0.9065, 0.7309:  32%|███▏      | 316/1000 [00:10<00:22, 29.77it/s][A[A

0.9269, 0.9065, 0.7309:  32%|███▏      | 320/1000 [00:10<00:22, 30.09it/s][A[A

0.9269, 0.9065, 0.7309:  32%|███▏      | 320/1000 [00:10<00:22, 30.09it/s][A[A

0.9269, 0.9065, 0.7309:  32%|███▏      | 320/1000 [00:10<00:22, 30.09it/s][A[A

0.9269, 0.9065, 

0.9269, 0.9065, 0.7309:  39%|███▉      | 392/1000 [00:12<00:17, 33.93it/s][A[A

0.9269, 0.9065, 0.7309:  39%|███▉      | 392/1000 [00:12<00:17, 33.93it/s][A[A

0.9269, 0.9065, 0.7309:  39%|███▉      | 392/1000 [00:12<00:17, 33.93it/s][A[A

0.9269, 0.9065, 0.7309:  39%|███▉      | 392/1000 [00:12<00:17, 33.93it/s][A[A

0.9269, 0.9065, 0.7309:  39%|███▉      | 392/1000 [00:12<00:17, 33.93it/s][A[A

0.9269, 0.9065, 0.7309:  40%|███▉      | 396/1000 [00:12<00:18, 33.54it/s][A[A

0.9269, 0.9065, 0.7309:  40%|███▉      | 396/1000 [00:12<00:18, 33.54it/s][A[A

0.9269, 0.9065, 0.7309:  40%|███▉      | 396/1000 [00:12<00:18, 33.54it/s][A[A

0.9269, 0.9065, 0.7309:  40%|███▉      | 396/1000 [00:12<00:18, 33.54it/s][A[A

0.9269, 0.9065, 0.7309:  40%|███▉      | 396/1000 [00:12<00:18, 33.54it/s][A[A

0.9269, 0.9065, 0.7309:  40%|████      | 400/1000 [00:12<00:17, 33.75it/s][A[A

0.9269, 0.9065, 0.7309:  40%|████      | 400/1000 [00:12<00:17, 33.75it/s][A[A

0.9269, 0.9065, 

0.9269, 0.9065, 0.7309:  47%|████▋     | 468/1000 [00:14<00:16, 32.02it/s][A[A

0.9269, 0.9065, 0.7309:  47%|████▋     | 472/1000 [00:14<00:16, 32.75it/s][A[A

0.9269, 0.9065, 0.7309:  47%|████▋     | 472/1000 [00:14<00:16, 32.75it/s][A[A

0.9269, 0.9065, 0.7309:  47%|████▋     | 472/1000 [00:14<00:16, 32.75it/s][A[A

0.9269, 0.9065, 0.7309:  47%|████▋     | 472/1000 [00:14<00:16, 32.75it/s][A[A

0.9269, 0.9065, 0.7309:  47%|████▋     | 472/1000 [00:15<00:16, 32.75it/s][A[A

0.9269, 0.9065, 0.7309:  48%|████▊     | 476/1000 [00:15<00:16, 32.64it/s][A[A

0.9269, 0.9065, 0.7309:  48%|████▊     | 476/1000 [00:15<00:16, 32.64it/s][A[A

0.9269, 0.9065, 0.7309:  48%|████▊     | 476/1000 [00:15<00:16, 32.64it/s][A[A

0.9269, 0.9065, 0.7309:  48%|████▊     | 476/1000 [00:15<00:16, 32.64it/s][A[A

0.9269, 0.9065, 0.7309:  48%|████▊     | 476/1000 [00:15<00:16, 32.64it/s][A[A

0.9269, 0.9065, 0.7309:  48%|████▊     | 480/1000 [00:15<00:15, 32.59it/s][A[A

0.9269, 0.9065, 

0.9269, 0.9065, 0.7309:  55%|█████▍    | 548/1000 [00:17<00:13, 34.26it/s][A[A

0.9269, 0.9065, 0.7309:  55%|█████▍    | 548/1000 [00:17<00:13, 34.26it/s][A[A

0.9269, 0.9065, 0.7309:  55%|█████▌    | 552/1000 [00:17<00:13, 34.36it/s][A[A

0.9269, 0.9065, 0.7309:  55%|█████▌    | 552/1000 [00:17<00:13, 34.36it/s][A[A

0.9269, 0.9065, 0.7309:  55%|█████▌    | 552/1000 [00:17<00:13, 34.36it/s][A[A

0.9269, 0.9065, 0.7309:  55%|█████▌    | 552/1000 [00:17<00:13, 34.36it/s][A[A

0.9269, 0.9065, 0.7309:  55%|█████▌    | 552/1000 [00:17<00:13, 34.36it/s][A[A

0.9269, 0.9065, 0.7309:  56%|█████▌    | 556/1000 [00:17<00:12, 34.35it/s][A[A

0.9269, 0.9065, 0.7309:  56%|█████▌    | 556/1000 [00:17<00:12, 34.35it/s][A[A

0.9269, 0.9065, 0.7309:  56%|█████▌    | 556/1000 [00:17<00:12, 34.35it/s][A[A

0.9269, 0.9065, 0.7309:  56%|█████▌    | 556/1000 [00:17<00:12, 34.35it/s][A[A

0.9269, 0.9065, 0.7309:  56%|█████▌    | 556/1000 [00:17<00:12, 34.35it/s][A[A

0.9269, 0.9065, 

0.9269, 0.9065, 0.7309:  63%|██████▎   | 628/1000 [00:19<00:11, 32.08it/s][A[A

0.9269, 0.9065, 0.7309:  63%|██████▎   | 628/1000 [00:19<00:11, 32.08it/s][A[A

0.9269, 0.9065, 0.7309:  63%|██████▎   | 628/1000 [00:19<00:11, 32.08it/s][A[A

0.9269, 0.9065, 0.7309:  63%|██████▎   | 632/1000 [00:19<00:11, 32.50it/s][A[A

0.9269, 0.9065, 0.7309:  63%|██████▎   | 632/1000 [00:19<00:11, 32.50it/s][A[A

0.9269, 0.9065, 0.7309:  63%|██████▎   | 632/1000 [00:19<00:11, 32.50it/s][A[A

0.9269, 0.9065, 0.7309:  63%|██████▎   | 632/1000 [00:19<00:11, 32.50it/s][A[A

0.9269, 0.9065, 0.7309:  63%|██████▎   | 632/1000 [00:19<00:11, 32.50it/s][A[A

0.9269, 0.9065, 0.7309:  64%|██████▎   | 636/1000 [00:19<00:11, 32.92it/s][A[A

0.9269, 0.9065, 0.7309:  64%|██████▎   | 636/1000 [00:19<00:11, 32.92it/s][A[A

0.9269, 0.9065, 0.7309:  64%|██████▎   | 636/1000 [00:19<00:11, 32.92it/s][A[A

0.9269, 0.9065, 0.7309:  64%|██████▎   | 636/1000 [00:19<00:11, 32.92it/s][A[A

0.9269, 0.9065, 

0.9269, 0.9065, 0.7309:  71%|███████   | 708/1000 [00:21<00:08, 33.95it/s][A[A

0.9269, 0.9065, 0.7309:  71%|███████   | 708/1000 [00:21<00:08, 33.95it/s][A[A

0.9269, 0.9065, 0.7309:  71%|███████   | 708/1000 [00:21<00:08, 33.95it/s][A[A

0.9269, 0.9065, 0.7309:  71%|███████   | 708/1000 [00:21<00:08, 33.95it/s][A[A

0.9269, 0.9065, 0.7309:  71%|███████   | 712/1000 [00:21<00:08, 34.31it/s][A[A

0.9269, 0.9065, 0.7309:  71%|███████   | 712/1000 [00:21<00:08, 34.31it/s][A[A

0.9269, 0.9065, 0.7309:  71%|███████   | 712/1000 [00:21<00:08, 34.31it/s][A[A

0.9269, 0.9065, 0.7309:  71%|███████   | 712/1000 [00:22<00:08, 34.31it/s][A[A

0.9269, 0.9065, 0.7309:  71%|███████   | 712/1000 [00:22<00:08, 34.31it/s][A[A

0.9269, 0.9065, 0.7309:  72%|███████▏  | 716/1000 [00:22<00:08, 34.41it/s][A[A

0.9269, 0.9065, 0.7309:  72%|███████▏  | 716/1000 [00:22<00:08, 34.41it/s][A[A

0.9269, 0.9065, 0.7309:  72%|███████▏  | 716/1000 [00:22<00:08, 34.41it/s][A[A

0.9269, 0.9065, 

0.9269, 0.9065, 0.7309:  79%|███████▉  | 788/1000 [00:24<00:07, 29.50it/s][A[A

0.9269, 0.9065, 0.7309:  79%|███████▉  | 788/1000 [00:24<00:07, 29.50it/s][A[A

0.9269, 0.9065, 0.7309:  79%|███████▉  | 788/1000 [00:24<00:07, 29.50it/s][A[A

0.9269, 0.9065, 0.7309:  79%|███████▉  | 788/1000 [00:24<00:07, 29.50it/s][A[A

0.9269, 0.9065, 0.7309:  79%|███████▉  | 788/1000 [00:24<00:07, 29.50it/s][A[A

0.9269, 0.9065, 0.7309:  79%|███████▉  | 792/1000 [00:24<00:06, 30.19it/s][A[A

0.9269, 0.9065, 0.7309:  79%|███████▉  | 792/1000 [00:24<00:06, 30.19it/s][A[A

0.9269, 0.9065, 0.7309:  79%|███████▉  | 792/1000 [00:24<00:06, 30.19it/s][A[A

0.9269, 0.9065, 0.7309:  79%|███████▉  | 792/1000 [00:24<00:06, 30.19it/s][A[A

0.9269, 0.9065, 0.7309:  79%|███████▉  | 792/1000 [00:24<00:06, 30.19it/s][A[A

0.9269, 0.9065, 0.7309:  80%|███████▉  | 796/1000 [00:24<00:06, 30.71it/s][A[A

0.9269, 0.9065, 0.7309:  80%|███████▉  | 796/1000 [00:24<00:06, 30.71it/s][A[A

0.9269, 0.9065, 

0.9269, 0.9065, 0.7309:  86%|████████▋ | 864/1000 [00:26<00:04, 29.22it/s][A[A

0.9269, 0.9065, 0.7309:  86%|████████▋ | 864/1000 [00:27<00:04, 29.22it/s][A[A

0.9269, 0.9065, 0.7309:  86%|████████▋ | 864/1000 [00:27<00:04, 29.22it/s][A[A

0.9269, 0.9065, 0.7309:  87%|████████▋ | 867/1000 [00:27<00:04, 29.39it/s][A[A

0.9269, 0.9065, 0.7309:  87%|████████▋ | 867/1000 [00:27<00:04, 29.39it/s][A[A

0.9269, 0.9065, 0.7309:  87%|████████▋ | 867/1000 [00:27<00:04, 29.39it/s][A[A

0.9269, 0.9065, 0.7309:  87%|████████▋ | 867/1000 [00:27<00:04, 29.39it/s][A[A

0.9269, 0.9065, 0.7309:  87%|████████▋ | 867/1000 [00:27<00:04, 29.39it/s][A[A

0.9269, 0.9065, 0.7309:  87%|████████▋ | 871/1000 [00:27<00:04, 30.44it/s][A[A

0.9269, 0.9065, 0.7309:  87%|████████▋ | 871/1000 [00:27<00:04, 30.44it/s][A[A

0.9269, 0.9065, 0.7309:  87%|████████▋ | 871/1000 [00:27<00:04, 30.44it/s][A[A

0.9269, 0.9065, 0.7309:  87%|████████▋ | 871/1000 [00:27<00:04, 30.44it/s][A[A

0.9269, 0.9065, 

0.9269, 0.9065, 0.7309:  94%|█████████▍| 943/1000 [00:29<00:01, 29.23it/s][A[A

0.9269, 0.9065, 0.7309:  94%|█████████▍| 943/1000 [00:29<00:01, 29.23it/s][A[A

0.9269, 0.9065, 0.7309:  94%|█████████▍| 943/1000 [00:29<00:01, 29.23it/s][A[A

0.9269, 0.9065, 0.7309:  95%|█████████▍| 946/1000 [00:29<00:01, 27.24it/s][A[A

0.9269, 0.9065, 0.7309:  95%|█████████▍| 946/1000 [00:29<00:01, 27.24it/s][A[A

0.9269, 0.9065, 0.7309:  95%|█████████▍| 946/1000 [00:29<00:01, 27.24it/s][A[A

0.9269, 0.9065, 0.7309:  95%|█████████▍| 946/1000 [00:29<00:01, 27.24it/s][A[A

0.9269, 0.9065, 0.7309:  95%|█████████▍| 946/1000 [00:29<00:01, 27.24it/s][A[A

0.9269, 0.9065, 0.7309:  95%|█████████▌| 950/1000 [00:29<00:01, 28.14it/s][A[A

0.9269, 0.9065, 0.7309:  95%|█████████▌| 950/1000 [00:29<00:01, 28.14it/s][A[A

0.9269, 0.9065, 0.7309:  95%|█████████▌| 950/1000 [00:29<00:01, 28.14it/s][A[A

0.9269, 0.9065, 0.7309:  95%|█████████▌| 950/1000 [00:29<00:01, 28.14it/s][A[A

0.9269, 0.9065, 

In [69]:
## Source

In [121]:
train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(dataset_difference, train_difference), batch_size=1024, shuffle=True)
val_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(dataset_difference, val_difference), batch_size=1024, shuffle=False)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = Net()
model = model.to(device)
optimizer = optim.SGD(model.parameters(), 0.1, momentum=0.9, weight_decay=0.)
# optimizer = optim.Adam(model.parameters(), 0.1)
# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [40, 60, 80], 0.8)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [300, 600, 900], 0.1)
criteria = nn.MSELoss()
train(150)

0.6646, 0.9130, 0.8063: 100%|██████████| 150/150 [00:09<00:00, 15.03it/s]


In [122]:
train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(dataset_intersect, train_intersect), batch_size=1024, shuffle=True)
val_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(dataset_intersect, val_intersect), batch_size=1024, shuffle=False)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = Net()
model = model.to(device)
optimizer = optim.SGD(model.parameters(), 0.001, momentum=0.9, weight_decay=0.)
# optimizer = optim.Adam(model.parameters(), 0.1)
# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [40, 60, 80], 0.8)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [300, 600, 900], 0.1)
criteria = nn.MSELoss()
train(400)

0.5841, 0.6083, 0.6083: 100%|██████████| 400/400 [00:10<00:00, 37.92it/s]


In [11]:
from sklearn import svm
TS = torch.utils.data.Subset(dataset, train_data)
VS = torch.utils.data.Subset(dataset, val_data)
train_loader = torch.utils.data.DataLoader(TS, batch_size=len(TS), shuffle=True)
val_loader = torch.utils.data.DataLoader(VS, batch_size=len(VS), shuffle=False)
for X,y in train_loader:
    pass
regr = svm.SVR(kernel="poly", degree=10, gamma="scale", tol=1e-4, verbose=True)
regr.fit(X, y)

ModuleNotFoundError: No module named 'sklearn'

In [19]:
for X,y in val_loader:
    pass
T = regr.predict(X).astype("float32")
criteria(torch.Tensor(T), y)

NameError: name 'val_loader' is not defined

In [20]:

for i in dataset_intersect:
    a = i[-1]
    
    break

NameError: name 'dataset_intersect' is not defined