In [1]:
from torch_geometric.datasets import MovieLens

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
# Load PyG's MovieLens dataset (fetched/processed under `root` on first run).
# It holds a single heterogeneous graph: movie nodes with a feature matrix `x`,
# featureless user nodes, and (user, rates, movie) edges labelled with ratings.
dataset = MovieLens(root="../data/movielens")
data = dataset[0]
data

HeteroData(
  movie={ x=[9742, 404] },
  user={ num_nodes=610 },
  (user, rates, movie)={
    edge_index=[2, 100836],
    edge_label=[100836],
  }
)

In [16]:
# COO edge index of the rating relation: row 0 = source nodes, row 1 = target nodes.
data["user", "rates", "movie"].edge_index

tensor([[   0,    0,    0,  ...,  609,  609,  609],
        [   0,    2,    5,  ..., 9462, 9463, 9503]])

In [17]:
# Movie feature matrix: one row per movie node.
data["movie"].x.shape

torch.Size([9742, 404])

In [18]:
# Everything stored on the rating relation (edge_index and edge_label).
data["user", "rates", "movie"]

{'edge_index': tensor([[   0,    0,    0,  ...,  609,  609,  609],
        [   0,    2,    5,  ..., 9462, 9463, 9503]]), 'edge_label': tensor([4, 4, 4,  ..., 5, 5, 3])}

In [19]:
# Same tensor as the earlier edge-index cell, fetched via the (src, dst)
# shorthand, which resolves to the unique (user, rates, movie) relation.
data["user", "movie"].edge_index

tensor([[   0,    0,    0,  ...,  609,  609,  609],
        [   0,    2,    5,  ..., 9462, 9463, 9503]])

In [20]:
# Per-row maxima of the edge index. 609 < 610 user nodes and 9741 < 9742 movie
# nodes, so row 0 holds user indices and row 1 holds movie indices.
tuple(data["user", "movie"].edge_index.max(dim=1).values)

(tensor(609), tensor(9741))

In [21]:
# edge_label is the rating the user gave the movie; show its observed range.
# NOTE(review): a minimum of 0 suggests fractional ratings (e.g. 0.5) were cast
# to integers when the dataset was built — confirm before treating these as exact.
data["user", "movie"].edge_label.min(), data["user", "movie"].edge_label.max()

(tensor(0), tensor(5))

In [22]:
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData

# The rating relation has 100836 edges in a single direction (user -> movie);
# no reverse relation exists yet. ToUndirected adds a reversed copy,
# (movie, rev_rates, user), including a copy of edge_label, so messages can also
# flow from movies back to users. merge=True is the transform's default.
data = T.ToUndirected(merge=True)(data)
data

HeteroData(
  movie={ x=[9742, 404] },
  user={ num_nodes=610 },
  (user, rates, movie)={
    edge_index=[2, 100836],
    edge_label=[100836],
  },
  (movie, rev_rates, user)={
    edge_index=[2, 100836],
    edge_label=[100836],
  }
)

In [23]:
# Drop the rating labels from both directions:
# - the reverse (movie, rev_rates, user) edges exist only for message passing and
#   never need labels;
# - removing the forward ratings makes RandomLinkSplit (below) generate binary
#   link labels (1 = observed edge, 0 = sampled negative) instead of carrying
#   the ratings through.
# The membership check keeps the cell re-runnable: a bare `del` raises once the
# labels are already gone.
for edge_type in (('movie', 'rev_rates', 'user'), ('user', 'rates', 'movie')):
    if 'edge_label' in data[edge_type]:
        del data[edge_type].edge_label
data

HeteroData(
  movie={ x=[9742, 404] },
  user={ num_nodes=610 },
  (user, rates, movie)={ edge_index=[2, 100836] },
  (movie, rev_rates, user)={ edge_index=[2, 100836] }
)

In [24]:


# NOTE(review): abandoned alternative kept for reference only. It overwrote the
# rating labels with all-ones to get binary link labels. The notebook instead
# deletes `edge_label` beforehand and lets RandomLinkSplit create the binary
# labels, so this cell does nothing and can be removed.
# x = data[('user', 'rates', 'movie')]['edge_label']
# y = torch.ones(len(x), dtype=torch.long)
# data[('user', 'rates', 'movie')]['edge_label'] = y

# x.shape, y.shape

In [34]:
# Split the rating edges into train/val/test for binary link prediction.
# Seed every RNG the split touches (torch for the edge permutation, Python's
# `random` for negative sampling). This keeps the split, and the files saved
# later, reproducible under Restart & Run All.
from torch_geometric import seed_everything

SEED = 42
seed_everything(SEED)

transform = T.RandomLinkSplit(
    num_val=0.1,  # 10% of rating edges held out for validation
    num_test=0.1,  # 10% of rating edges held out for test
    # Of the remaining training edges, 30% are used only as supervision targets
    # and are excluded from the message-passing edge_index.
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=2.0,  # 2 sampled negative edges per positive (val/test)
    # No negatives in the training labels; they must be sampled during training.
    add_negative_train_samples=False,
    edge_types=("user", "rates", "movie"),
    rev_edge_types=("movie", "rev_rates", "user"),
)

train_data, val_data, test_data = transform(data)
train_data, val_data, test_data

(HeteroData(
   movie={ x=[9742, 404] },
   user={ num_nodes=610 },
   (user, rates, movie)={
     edge_index=[2, 56469],
     edge_label=[24201],
     edge_label_index=[2, 24201],
   },
   (movie, rev_rates, user)={ edge_index=[2, 56469] }
 ),
 HeteroData(
   movie={ x=[9742, 404] },
   user={ num_nodes=610 },
   (user, rates, movie)={
     edge_index=[2, 80670],
     edge_label=[30249],
     edge_label_index=[2, 30249],
   },
   (movie, rev_rates, user)={ edge_index=[2, 80670] }
 ),
 HeteroData(
   movie={ x=[9742, 404] },
   user={ num_nodes=610 },
   (user, rates, movie)={
     edge_index=[2, 90753],
     edge_label=[30249],
     edge_label_index=[2, 30249],
   },
   (movie, rev_rates, user)={ edge_index=[2, 90753] }
 ))

In [35]:
# Supervision labels attached to the training split's rating edges.
train_data["user", "rates", "movie"].edge_label

tensor([1., 1., 1.,  ..., 1., 1., 1.])

In [36]:
# Sanity check: with add_negative_train_samples=False every training label should
# be 1 (positives only), so the count of ones must equal the number of labels.
# Use a vectorised tensor reduction. Python's builtin sum() iterates the tensor
# element by element, creating one 0-d tensor per label.
(
    (train_data['user', 'rates', 'movie'].edge_label == 1).sum(),
    train_data['user', 'rates', 'movie'].edge_label.shape,
)

(tensor(24201), torch.Size([24201]))

In [40]:
import networkx as nx
from torch_geometric.utils import to_networkx

# Convert the test split's graph to NetworkX and write it as GEXF (e.g. for
# Gephi). NOTE(review): presumably only the message-passing `edge_index` edges
# are exported, not the held-out `edge_label_index` pairs — confirm.
G = to_networkx(test_data)
nx.write_gexf(G, "../data/movielens/movielens_test.gexf")

In [37]:
import torch

# Persist each split to the file matching its name. The validation split goes to
# val.pt and the test split to test.pt, so downstream evaluation loads the
# intended edges.
torch.save(train_data, '../data/movielens/train.pt')
torch.save(val_data, '../data/movielens/val.pt')
torch.save(test_data, '../data/movielens/test.pt')
