In [1]:
import scala.io.Source

val filePath = "data/the_verdict.txt"
val rawText = Source.fromFile(filePath).mkString

println(s"Read ${rawText.length} characters from $filePath")

Read 20479 characters from data/the_verdict.txt


[32mimport [39m[36mscala.io.Source[39m
[36mfilePath[39m: [32mString[39m = [32m"data/the_verdict.txt"[39m
[36mrawText[39m: [32mString[39m = [32m"""I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no great surprise to me to hear that, in the height of his glory, he had dropped his painting, married a rich widow, and established himself in a villa on the Riviera. (Though I rather thought it would have been Rome or Florence.)

"The height of his glory"--that was what the women called it. I can hear Mrs. Gideon Thwing--his last Chicago sitter--deploring his unaccountable abdication. "Of course it's going to send the value of my picture 'way up; but I don't think of that, Mr. Rickham--the loss to Arrt is all I think of." The word, on Mrs. Thwing's lips, multiplied its _rs_ as though they were reflected in an endless vista of mirrors. And it was not only the Mrs. Thwings who mourned. Had not the exquisite Hermia Croft, at the last 

In [2]:
def tokenize(text: String): Vector[String] = {
  val splitBy = """[,.:;?_!"()\']|--|\s"""
  text.split(s"(?<=$splitBy)|(?=$splitBy)").filter(!_.isBlank).toVector
}

println(tokenize("Hello, world. Is this-- a test?"))

Vector(Hello, ,, world, ., Is, this, --, a, test, ?)


defined [32mfunction[39m [36mtokenize[39m

In [3]:
val tokenizedText = tokenize(rawText)

println(s"Extracted ${tokenizedText.length} tokens")

Extracted 4690 tokens


[36mtokenizedText[39m: [32mVector[39m[[32mString[39m] = [33mVector[39m(
  [32m"I"[39m,
  [32m"HAD"[39m,
  [32m"always"[39m,
  [32m"thought"[39m,
  [32m"Jack"[39m,
  [32m"Gisburn"[39m,
  [32m"rather"[39m,
  [32m"a"[39m,
  [32m"cheap"[39m,
  [32m"genius"[39m,
  [32m"--"[39m,
  [32m"though"[39m,
  [32m"a"[39m,
  [32m"good"[39m,
  [32m"fellow"[39m,
  [32m"enough"[39m,
  [32m"--"[39m,
  [32m"so"[39m,
  [32m"it"[39m,
  [32m"was"[39m,
  [32m"no"[39m,
  [32m"great"[39m,
  [32m"surprise"[39m,
  [32m"to"[39m,
  [32m"me"[39m,
  [32m"to"[39m,
  [32m"hear"[39m,
  [32m"that"[39m,
  [32m","[39m,
  [32m"in"[39m,
  [32m"the"[39m,
  [32m"height"[39m,
  [32m"of"[39m,
  [32m"his"[39m,
  [32m"glory"[39m,
  [32m","[39m,
  [32m"he"[39m,
  [32m"had"[39m,
...

In [4]:
val sortedDistinctTokens = tokenizedText.sorted.distinct

println(s"${sortedDistinctTokens.length} distinct tokens in total")

1130 distinct tokens in total


[36msortedDistinctTokens[39m: [32mVector[39m[[32mString[39m] = [33mVector[39m(
  [32m"!"[39m,
  [32m"\""[39m,
  [32m"'"[39m,
  [32m"("[39m,
  [32m")"[39m,
  [32m","[39m,
  [32m"--"[39m,
  [32m"."[39m,
  [32m":"[39m,
  [32m";"[39m,
  [32m"?"[39m,
  [32m"A"[39m,
  [32m"Ah"[39m,
  [32m"Among"[39m,
  [32m"And"[39m,
  [32m"Are"[39m,
  [32m"Arrt"[39m,
  [32m"As"[39m,
  [32m"At"[39m,
  [32m"Be"[39m,
  [32m"Begin"[39m,
  [32m"Burlington"[39m,
  [32m"But"[39m,
  [32m"By"[39m,
  [32m"Carlo"[39m,
  [32m"Chicago"[39m,
  [32m"Claude"[39m,
  [32m"Come"[39m,
  [32m"Croft"[39m,
  [32m"Destroyed"[39m,
  [32m"Devonshire"[39m,
  [32m"Don"[39m,
  [32m"Dubarry"[39m,
  [32m"Emperors"[39m,
  [32m"Florence"[39m,
  [32m"For"[39m,
  [32m"Gallery"[39m,
  [32m"Gideon"[39m,
...

In [5]:
val vocabulary = sortedDistinctTokens.zipWithIndex.toMap

class SimpleTokenizerV1(
  vocabulary: Map[String, Int]
) {
  val inverseVocabulary = vocabulary.map(_.swap)

  def encode(text: String): Vector[Int] = 
    tokenize(text).map(vocabulary(_))

  def tokenize(text: String): Vector[String] = {
    val splitBy = """[,.:;?_!"()\']|--|\s"""
    val tokenizer = s"(?<=$splitBy)|(?=$splitBy)"
    text.split(tokenizer).filter(!_.isBlank).toVector
  }

  def decode(ids: Vector[Int]): String = 
    ids
      .map(inverseVocabulary(_))
      .mkString(" ")
      .replaceAll("\\s+([,.?!\"()\'])", "$1") 
}

[36mvocabulary[39m: [32mMap[39m[[32mString[39m, [32mInt[39m] = [33mHashMap[39m(
  [32m"inevitable"[39m -> [32m571[39m,
  [32m"Monte"[39m -> [32m64[39m,
  [32m"down"[39m -> [32m362[39m,
  [32m"economy"[39m -> [32m377[39m,
  [32m"interesting"[39m -> [32m578[39m,
  [32m"luxury"[39m -> [32m652[39m,
  [32m"serious"[39m -> [32m870[39m,
  [32m"forgotten"[39m -> [32m463[39m,
  [32m"muscles"[39m -> [32m695[39m,
  [32m"beneath"[39m -> [32m215[39m,
  [32m"used"[39m -> [32m1057[39m,
  [32m"eye"[39m -> [32m415[39m,
  [32m"straining"[39m -> [32m934[39m,
  [32m"At"[39m -> [32m18[39m,
  [32m"hooded"[39m -> [32m554[39m,
  [32m"murmur"[39m -> [32m694[39m,
  [32m"adulation"[39m -> [32m133[39m,
  [32m"gloried"[39m -> [32m495[39m,
  [32m"widow"[39m -> [32m1102[39m,
  [32m"panel"[39m -> [32m752[39m,
  [32m"sitters"[39m -> [32m898[39m,
  [32m"quality"[39m -> [32m808[39m,
  [32m"On"[39m -> [32m75[39m,
  [32m

In [6]:
val tokenizer = new SimpleTokenizerV1(vocabulary)

val textToEncode = """"It's the last he painted, you know,"
Mrs. Gisburn said with pardonable pride."""
val ids = tokenizer.encode(textToEncode)
val decodedText = tokenizer.decode(ids)

println(decodedText)

" It' s the last he painted, you know," Mrs. Gisburn said with pardonable pride.


[36mtokenizer[39m: [32mSimpleTokenizerV1[39m = ammonite.$sess.cmd5$Helper$SimpleTokenizerV1@1081f6e2
[36mtextToEncode[39m: [32mString[39m = [32m""""It's the last he painted, you know,"
Mrs. Gisburn said with pardonable pride."""[39m
[36mids[39m: [32mVector[39m[[32mInt[39m] = [33mVector[39m(
  [32m1[39m,
  [32m56[39m,
  [32m2[39m,
  [32m850[39m,
  [32m988[39m,
  [32m602[39m,
  [32m533[39m,
  [32m746[39m,
  [32m5[39m,
  [32m1126[39m,
  [32m596[39m,
  [32m5[39m,
  [32m1[39m,
  [32m67[39m,
  [32m7[39m,
  [32m38[39m,
  [32m851[39m,
  [32m1108[39m,
  [32m754[39m,
  [32m793[39m,
  [32m7[39m
)
[36mdecodedText[39m: [32mString[39m = [32m"\" It' s the last he painted, you know,\" Mrs. Gisburn said with pardonable pride."[39m

In [7]:
val endOfText = "<|endoftext|>"
val vocabularyWithEndOfText = (sortedDistinctTokens :+ endOfText).zipWithIndex.toMap
val (unknownTokenId, unknownToken) = -1 -> "<|unknown|>"

class SimpleTokenizerV2(
  vocabulary: Map[String, Int]
) {
  val inverseVocabulary = vocabulary.map(_.swap)

  def encode(text: String): Vector[Int] = 
    tokenize(text).map(vocabulary.getOrElse(_, unknownTokenId))

  def tokenize(text: String): Vector[String] = {
    val splitBy = """[,.:;?_!"()\']|--|\s"""
    val tokenizer = s"(?<=$splitBy)|(?=$splitBy)"
    text.split(tokenizer).filter(!_.isBlank).toVector
  }

  def decode(ids: Vector[Int]): String = 
    ids
      .map(inverseVocabulary.getOrElse(_, unknownToken))
      .mkString(" ")
      .replaceAll("\\s+([,.?!\"()\'])", "$1") 
}

[36mendOfText[39m: [32mString[39m = [32m"<|endoftext|>"[39m
[36mvocabularyWithEndOfText[39m: [32mMap[39m[[32mString[39m, [32mInt[39m] = [33mHashMap[39m(
  [32m"inevitable"[39m -> [32m571[39m,
  [32m"Monte"[39m -> [32m64[39m,
  [32m"down"[39m -> [32m362[39m,
  [32m"economy"[39m -> [32m377[39m,
  [32m"interesting"[39m -> [32m578[39m,
  [32m"luxury"[39m -> [32m652[39m,
  [32m"serious"[39m -> [32m870[39m,
  [32m"forgotten"[39m -> [32m463[39m,
  [32m"muscles"[39m -> [32m695[39m,
  [32m"beneath"[39m -> [32m215[39m,
  [32m"used"[39m -> [32m1057[39m,
  [32m"eye"[39m -> [32m415[39m,
  [32m"straining"[39m -> [32m934[39m,
  [32m"At"[39m -> [32m18[39m,
  [32m"hooded"[39m -> [32m554[39m,
  [32m"murmur"[39m -> [32m694[39m,
  [32m"adulation"[39m -> [32m133[39m,
  [32m"gloried"[39m -> [32m495[39m,
  [32m"widow"[39m -> [32m1102[39m,
  [32m"panel"[39m -> [32m752[39m,
  [32m"sitters"[39m -> [32m898[39m,
  

In [8]:
val tokenizerV2 = new SimpleTokenizerV2(vocabulary)

val concatenatedText = "Hello, do you like tea?" + s" $endOfText " + "In the sunlit terraces of the palace."

import scala.util.chaining._
println(concatenatedText.pipe(tokenizerV2.encode).pipe(tokenizerV2.decode))

<|unknown|>, do you like tea? <|unknown|> In the sunlit terraces of the <|unknown|>.


[36mtokenizerV2[39m: [32mSimpleTokenizerV2[39m = ammonite.$sess.cmd7$Helper$SimpleTokenizerV2@33489346
[36mconcatenatedText[39m: [32mString[39m = [32m"Hello, do you like tea? <|endoftext|> In the sunlit terraces of the palace."[39m
[32mimport [39m[36mscala.util.chaining._[39m

In [9]:
import $ivy.`dev.scalapy::scalapy-core:0.5.3`

Downloading https://repo1.maven.org/maven2/sh/almond/almond-scalapy_2.13/0.14.0-RC15/almond-scalapy_2.13-0.14.0-RC15.pom
Downloading https://repo1.maven.org/maven2/dev/scalapy/scalapy-core_2.13/0.5.3/scalapy-core_2.13-0.5.3.pom
Downloaded https://repo1.maven.org/maven2/sh/almond/almond-scalapy_2.13/0.14.0-RC15/almond-scalapy_2.13-0.14.0-RC15.pom
Downloaded https://repo1.maven.org/maven2/dev/scalapy/scalapy-core_2.13/0.5.3/scalapy-core_2.13-0.5.3.pom
Downloading https://repo1.maven.org/maven2/dev/scalapy/scalapy-macros_2.13/0.5.3/scalapy-macros_2.13-0.5.3.pom
Downloaded https://repo1.maven.org/maven2/dev/scalapy/scalapy-macros_2.13/0.5.3/scalapy-macros_2.13-0.5.3.pom
Downloading https://repo1.maven.org/maven2/dev/scalapy/scalapy-macros_2.13/0.5.3/scalapy-macros_2.13-0.5.3.jar
Downloading https://repo1.maven.org/maven2/dev/scalapy/scalapy-core_2.13/0.5.3/scalapy-core_2.13-0.5.3-sources.jar
Downloading https://repo1.maven.org/maven2/dev/scalapy/scalapy-macros_2.13/0.5.3/scalapy-macros_2.1

[32mimport [39m[36m$ivy.$[39m

In [10]:
import $file.Magic

Compiling /workspace/chapter2/Magic.sc


[32mimport [39m[36m$file.$[39m

In [11]:
Magic.!("pip", "install", "tiktoken==0.7.*")

Collecting tiktoken==0.7.*
  Downloading tiktoken-0.7.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.metadata (6.6 kB)
Collecting regex>=2022.1.18 (from tiktoken==0.7.*)
  Downloading regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.metadata (40 kB)
Downloading tiktoken-0.7.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.0 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.0/1.0 MB 10.8 MB/s eta 0:00:00
Downloading regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (794 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 795.0/795.0 kB 33.0 MB/s eta 0:00:00
Installing collected packages: regex, tiktoken




Successfully installed regex-2024.11.6 tiktoken-0.7.0


In [12]:
import me.shadaj.scalapy.py

val tiktoken = py.module("tiktoken")

[32mimport [39m[36mme.shadaj.scalapy.py[39m
[36mtiktoken[39m: [32mpy[39m.[32mModule[39m = <module 'tiktoken' from '/usr/local/lib/python3.12/site-packages/tiktoken/__init__.py'>

In [13]:
import py.SeqConverters

val tiktokenizer = tiktoken.get_encoding("gpt2")
val tiktext = s"Hello, do you like tea? $endOfText In the sunlit terraces of someunknownPlace"
val allowedSpecial = py.Dynamic.global.set(Seq(endOfText).toPythonProxy)
val tiktokens = tiktokenizer.encode(tiktext, allowed_special = allowedSpecial).as[Vector[Int]]
val decodedTiktokens = tiktokenizer.decode(tiktokens.toPythonProxy).as[String]

println(decodedTiktokens)

Hello, do you like tea? <|endoftext|> In the sunlit terraces of someunknownPlace


[32mimport [39m[36mpy.SeqConverters[39m
[36mtiktokenizer[39m: [32mpy[39m.[32mDynamic[39m = <Encoding 'gpt2'>
[36mtiktext[39m: [32mString[39m = [32m"Hello, do you like tea? <|endoftext|> In the sunlit terraces of someunknownPlace"[39m
[36mallowedSpecial[39m: [32mpy[39m.[32mDynamic[39m = {'<|endoftext|>'}
[36mtiktokens[39m: [32mVector[39m[[32mInt[39m] = [33mVector[39m(
  [32m15496[39m,
  [32m11[39m,
  [32m466[39m,
  [32m345[39m,
  [32m588[39m,
  [32m8887[39m,
  [32m30[39m,
  [32m220[39m,
  [32m50256[39m,
  [32m554[39m,
  [32m262[39m,
  [32m4252[39m,
  [32m18250[39m,
  [32m8812[39m,
  [32m2114[39m,
  [32m286[39m,
  [32m617[39m,
  [32m34680[39m,
  [32m27271[39m
)
[36mdecodedTiktokens[39m: [32mString[39m = [32m"Hello, do you like tea? <|endoftext|> In the sunlit terraces of someunknownPlace"[39m

In [14]:
// Exercise 2.1
val unknownWords = "Akwirw ier"
println(s"Input: $unknownWords")

val encodedUnknownWords = tiktokenizer.encode(unknownWords).as[Vector[Int]]
val encoding = encodedUnknownWords.map(int => int -> tiktokenizer.decode(Seq(int).toPythonProxy).as[String]).toMap

println("Encoding: ")
encoding.toList.sortBy { case (_, subword) => subword }.foreach { case (int, subword) =>
  println(s"  \"$subword\" -> $int")
}

val decodedUnknownWords = tiktokenizer.decode(encodedUnknownWords.toPythonProxy).as[String]
println(s"Decoded: $decodedUnknownWords")

Input: Akwirw ier
Encoding: 
  " " -> 220
  "Ak" -> 33901
  "ier" -> 959
  "ir" -> 343
  "w" -> 86
Decoded: Akwirw ier


[36munknownWords[39m: [32mString[39m = [32m"Akwirw ier"[39m
[36mencodedUnknownWords[39m: [32mVector[39m[[32mInt[39m] = [33mVector[39m([32m33901[39m, [32m86[39m, [32m343[39m, [32m86[39m, [32m220[39m, [32m959[39m)
[36mencoding[39m: [32mMap[39m[[32mInt[39m, [32mString[39m] = [33mHashMap[39m(
  [32m220[39m -> [32m" "[39m,
  [32m33901[39m -> [32m"Ak"[39m,
  [32m343[39m -> [32m"ir"[39m,
  [32m86[39m -> [32m"w"[39m,
  [32m959[39m -> [32m"ier"[39m
)
[36mdecodedUnknownWords[39m: [32mString[39m = [32m"Akwirw ier"[39m

In [15]:
val rawTextTokenized = tiktokenizer.encode(rawText).as[Vector[Int]]

println(s"Raw text token count: ${rawTextTokenized.length}")

val rawTextTokensSampled = rawTextTokenized.drop(50)

val contextSize = 4

(1 to contextSize).foreach { size =>
  val context = rawTextTokensSampled.take(size)
  val desired = rawTextTokensSampled(size)
  println(s"${tiktokenizer.decode(context.toPythonProxy)} ---> ${tiktokenizer.decode(Seq(desired).toPythonProxy)}")
}

Raw text token count: 5145
 and --->  established
 and established --->  himself
 and established himself --->  in
 and established himself in --->  a


[36mrawTextTokenized[39m: [32mVector[39m[[32mInt[39m] = [33mVector[39m(
  [32m40[39m,
  [32m367[39m,
  [32m2885[39m,
  [32m1464[39m,
  [32m1807[39m,
  [32m3619[39m,
  [32m402[39m,
  [32m271[39m,
  [32m10899[39m,
  [32m2138[39m,
  [32m257[39m,
  [32m7026[39m,
  [32m15632[39m,
  [32m438[39m,
  [32m2016[39m,
  [32m257[39m,
  [32m922[39m,
  [32m5891[39m,
  [32m1576[39m,
  [32m438[39m,
  [32m568[39m,
  [32m340[39m,
  [32m373[39m,
  [32m645[39m,
  [32m1049[39m,
  [32m5975[39m,
  [32m284[39m,
  [32m502[39m,
  [32m284[39m,
  [32m3285[39m,
  [32m326[39m,
  [32m11[39m,
  [32m287[39m,
  [32m262[39m,
  [32m6001[39m,
  [32m286[39m,
  [32m465[39m,
  [32m13476[39m,
...
[36mrawTextTokensSampled[39m: [32mVector[39m[[32mInt[39m] = [33mVector[39m(
  [32m290[39m,
  [32m4920[39m,
  [32m2241[39m,
  [32m287[39m,
  [32m257[39m,
  [32m4489[39m,
  [32m64[39m,
  [32m319[39m,
  [32m262[39m,
  [32m34686[

In [16]:
Magic.!("pip", "install", "torch==2.5.*")

Collecting torch==2.5.*
  Downloading torch-2.5.1-cp312-cp312-manylinux2014_aarch64.whl.metadata (28 kB)
Collecting filelock (from torch==2.5.*)
  Downloading filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting networkx (from torch==2.5.*)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting fsspec (from torch==2.5.*)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting sympy==1.13.1 (from torch==2.5.*)
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy==1.13.1->torch==2.5.*)
  Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.5.1-cp312-cp312-manylinux2014_aarch64.whl (91.8 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 91.8/91.8 MB 5.9 MB/s eta 0:00:00
Downloading sympy-1.13.1-py3-none-any.whl (6.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.2/6.2 MB 18.6 MB/s eta 0:00:00
Downloading filelock-3.16.1-py3-none-any.whl (16 kB)
Downloading fsspec-



In [17]:
import me.shadaj.scalapy.interpreter.CPythonInterpreter

val torch = py.module("torch")

// Workaround to define a class that inherits from a Python class
CPythonInterpreter.execManyLines {
  s"""from torch.utils.data import Dataset
     |
     |class GPTDatasetV1(Dataset):
     |  def __init__(self, input_tokens, target_tokens):
     |    self.input_tokens = input_tokens
     |    self.target_tokens = target_tokens
     |  
     |  def __len__(self):
     |    return len(self.input_tokens)
     |
     |  def __getitem__(self, index):
     |    return self.input_tokens[index], self.target_tokens[index]
     |""".stripMargin
}
type Tokenizer = py.Dynamic
type TorchTensor = py.Dynamic
def GPTDatasetV1(
  text: String,
  tokenizer: Tokenizer,
  maxLength: Int,
  step: Int
): py.Dynamic = {
  val tokens = tokenizer.encode(text).as[Vector[Int]]
  val (inputTokens, outputTokens) = (0 until tokens.length by step).foldLeft(
    (
      Vector.empty[TorchTensor], 
      Vector.empty[TorchTensor]
    )
  ) {
    case ((inputTokens, outputTokens), i) =>
      val inputChunk = tokens.slice(i, i + maxLength)
      val outputChunk = tokens.slice(i + 1, i + 1 + maxLength)
      (
       inputTokens :+ torch.tensor(inputChunk.toPythonProxy), 
       outputTokens :+ torch.tensor(outputChunk.toPythonProxy)
      )
  }
  py.Dynamic.global.GPTDatasetV1(inputTokens.toPythonProxy, outputTokens.toPythonProxy)
}

def createDataLoaderV1(
  text: String, 
  batchSize: Int = 4, 
  maxLength: Int = 256,                       
  step: Int = 128, 
  shuffle: Boolean = true, 
  dropLast: Boolean = true,
  numWorkers: Int = 0 
): py.Dynamic = {
  val tokenizer = tiktoken.get_encoding("gpt2")
  val dataset = GPTDatasetV1(text, tokenizer, maxLength, step)
  torch.utils.data.DataLoader(
    dataset,
    batch_size = batchSize,
    shuffle = shuffle,
    drop_last = dropLast,
    num_workers = numWorkers
  )
}

[32mimport [39m[36mme.shadaj.scalapy.interpreter.CPythonInterpreter[39m
[36mtorch[39m: [32mpy[39m.[32mModule[39m = <module 'torch' from '/usr/local/lib/python3.12/site-packages/torch/__init__.py'>
defined [32mtype[39m [36mTokenizer[39m
defined [32mtype[39m [36mTorchTensor[39m
defined [32mfunction[39m [36mGPTDatasetV1[39m
defined [32mfunction[39m [36mcreateDataLoaderV1[39m

In [18]:
val dataLoader = createDataLoaderV1(
  text = rawText, 
  batchSize = 1, 
  maxLength = 4, 
  step = 1, 
  shuffle = false
)
val dataLoaderIterator = py.Dynamic.global.iter(dataLoader)
val firstBatch = py.Dynamic.global.next(dataLoaderIterator)
val secondBatch = py.Dynamic.global.next(dataLoaderIterator)
println(firstBatch)
println(secondBatch)

[tensor([[  40,  367, 2885, 1464]]), tensor([[ 367, 2885, 1464, 1807]])]
[tensor([[ 367, 2885, 1464, 1807]]), tensor([[2885, 1464, 1807, 3619]])]


[36mdataLoader[39m: [32mpy[39m.[32mDynamic[39m = <torch.utils.data.dataloader.DataLoader object at 0xfffeb9e824e0>
[36mdataLoaderIterator[39m: [32mpy[39m.[32mDynamic[39m = <torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0xfffeb9e48410>
[36mfirstBatch[39m: [32mpy[39m.[32mDynamic[39m = [tensor([[  40,  367, 2885, 1464]]), tensor([[ 367, 2885, 1464, 1807]])]
[36msecondBatch[39m: [32mpy[39m.[32mDynamic[39m = [tensor([[ 367, 2885, 1464, 1807]]), tensor([[2885, 1464, 1807, 3619]])]

In [19]:
// Exercise 2.2
def printFirst2Batches(maxLength: Int, step: Int) = {
  println(s"maxLength = $maxLength, step = $step")
  val dataLoader = createDataLoaderV1(
    text = rawText, 
    batchSize = 1, 
    maxLength = maxLength, 
    step = step, 
    shuffle = false
  )
  val dataLoaderIterator = py.Dynamic.global.iter(dataLoader)
  val firstBatch = py.Dynamic.global.next(dataLoaderIterator)
  val secondBatch = py.Dynamic.global.next(dataLoaderIterator)
  println(firstBatch)
  println(secondBatch)
}

printFirst2Batches(maxLength = 2, step = 2)
printFirst2Batches(maxLength = 8, step = 2)

maxLength = 2, step = 2
[tensor([[ 40, 367]]), tensor([[ 367, 2885]])]
[tensor([[2885, 1464]]), tensor([[1464, 1807]])]
maxLength = 8, step = 2
[tensor([[  40,  367, 2885, 1464, 1807, 3619,  402,  271]]), tensor([[  367,  2885,  1464,  1807,  3619,   402,   271, 10899]])]
[tensor([[ 2885,  1464,  1807,  3619,   402,   271, 10899,  2138]]), tensor([[ 1464,  1807,  3619,   402,   271, 10899,  2138,   257]])]


defined [32mfunction[39m [36mprintFirst2Batches[39m

In [20]:
val batchedDataloader = createDataLoaderV1(
  text = rawText, 
  batchSize = 8, 
  maxLength = 4, 
  step = 4, // same as maxLength to prevent overlaps between inputs and outputs
  shuffle = false
)
val Seq(inputTokens, outputTokens) = py.Dynamic.global.next(py.Dynamic.global.iter(batchedDataloader)).as[Seq[TorchTensor]]
println(s"Inputs:\n$inputTokens")
println(s"Outputs:\n$outputTokens")

Inputs:
tensor([[   40,   367,  2885,  1464],
        [ 1807,  3619,   402,   271],
        [10899,  2138,   257,  7026],
        [15632,   438,  2016,   257],
        [  922,  5891,  1576,   438],
        [  568,   340,   373,   645],
        [ 1049,  5975,   284,   502],
        [  284,  3285,   326,    11]])
Outputs:
tensor([[  367,  2885,  1464,  1807],
        [ 3619,   402,   271, 10899],
        [ 2138,   257,  7026, 15632],
        [  438,  2016,   257,   922],
        [ 5891,  1576,   438,   568],
        [  340,   373,   645,  1049],
        [ 5975,   284,   502,   284],
        [ 3285,   326,    11,   287]])


[36mbatchedDataloader[39m: [32mpy[39m.[32mDynamic[39m = <torch.utils.data.dataloader.DataLoader object at 0xfffeb8cc4f50>
[36minputTokens[39m: [32mTorchTensor[39m = tensor([[   40,   367,  2885,  1464],
        [ 1807,  3619,   402,   271],
        [10899,  2138,   257,  7026],
        [15632,   438,  2016,   257],
        [  922,  5891,  1576,   438],
        [  568,   340,   373,   645],
        [ 1049,  5975,   284,   502],
        [  284,  3285,   326,    11]])
[36moutputTokens[39m: [32mTorchTensor[39m = tensor([[  367,  2885,  1464,  1807],
        [ 3619,   402,   271, 10899],
        [ 2138,   257,  7026, 15632],
        [  438,  2016,   257,   922],
        [ 5891,  1576,   438,   568],
        [  340,   373,   645,  1049],
        [ 5975,   284,   502,   284],
        [ 3285,   326,    11,   287]])

In [21]:
val vocabularySize = 50257 // tiktoken vocabulary size
val outputDimension = 256
val tokenEmbeddingLayer = torch.nn.Embedding(vocabularySize, outputDimension)

val maxLength = 4
val dataLoader = createDataLoaderV1(
  text = rawText, 
  batchSize = 8, 
  maxLength = maxLength, 
  step = maxLength,
  shuffle = false
)
val Seq(inputTokensBatchToEmbed, _) = py.Dynamic.global.next(py.Dynamic.global.iter(dataLoader)).as[Seq[TorchTensor]]
println(s"Input batch to embed:\n$inputTokensBatchToEmbed")

val inputTokensEmbedding = tokenEmbeddingLayer(inputTokensBatchToEmbed)
println(s"Input embedding shape: ${inputTokensEmbedding.shape}")

Input batch to embed:
tensor([[   40,   367,  2885,  1464],
        [ 1807,  3619,   402,   271],
        [10899,  2138,   257,  7026],
        [15632,   438,  2016,   257],
        [  922,  5891,  1576,   438],
        [  568,   340,   373,   645],
        [ 1049,  5975,   284,   502],
        [  284,  3285,   326,    11]])
Input embedding shape: torch.Size([8, 4, 256])


[36mvocabularySize[39m: [32mInt[39m = [32m50257[39m
[36moutputDimension[39m: [32mInt[39m = [32m256[39m
[36mtokenEmbeddingLayer[39m: [32mpy[39m.[32mDynamic[39m = Embedding(50257, 256)
[36mmaxLength[39m: [32mInt[39m = [32m4[39m
[36mdataLoader[39m: [32mpy[39m.[32mDynamic[39m = <torch.utils.data.dataloader.DataLoader object at 0xfffeb978a870>
[36minputTokensBatchToEmbed[39m: [32mTorchTensor[39m = tensor([[   40,   367,  2885,  1464],
        [ 1807,  3619,   402,   271],
        [10899,  2138,   257,  7026],
        [15632,   438,  2016,   257],
        [  922,  5891,  1576,   438],
        [  568,   340,   373,   645],
        [ 1049,  5975,   284,   502],
        [  284,  3285,   326,    11]])
[36minputTokensEmbedding[39m: [32mpy[39m.[32mDynamic[39m = tensor([[[-7.4044e-02, -1.0514e+00,  8.3380e-01,  ...,  9.6170e-02,
          -1.3377e+00, -9.2099e-01],
         [ 1.1112e+00, -1.1214e+00, -9.7873e-02,  ..., -5.9731e-01,
           1.3466e+00, -1.0

In [22]:
val contextLength = maxLength
val positionalEmbeddingLayer = torch.nn.Embedding(contextLength, outputDimension)
val positionalEmbedding = positionalEmbeddingLayer(torch.arange(contextLength))
println(s"Positional embedding shape: ${positionalEmbedding.shape}")

Positional embedding shape: torch.Size([4, 256])


[36mcontextLength[39m: [32mInt[39m = [32m4[39m
[36mpositionalEmbeddingLayer[39m: [32mpy[39m.[32mDynamic[39m = Embedding(4, 256)
[36mpositionalEmbedding[39m: [32mpy[39m.[32mDynamic[39m = tensor([[ 0.6511,  0.2852, -1.1656,  ...,  0.1865,  0.5942,  0.6211],
        [-1.1493, -1.3659,  1.5158,  ..., -0.3550,  1.1064,  1.6394],
        [ 0.2307, -3.2032, -0.5077,  ..., -1.4381,  0.1972,  0.4176],
        [ 0.6251,  0.4173, -0.6667,  ...,  1.0873,  0.2431,  1.1374]],
       grad_fn=<EmbeddingBackward0>)

In [23]:
import py.PyQuote

val combinedEmbedding = py"$inputTokensEmbedding + $positionalEmbedding"
println(s"Combined embedding shape: ${combinedEmbedding.shape}")

Combined embedding shape: torch.Size([8, 4, 256])


[32mimport [39m[36mpy.PyQuote[39m
[36mcombinedEmbedding[39m: [32mpy[39m.[32mDynamic[39m = tensor([[[ 0.5770, -0.7663, -0.3318,  ...,  0.2826, -0.7435, -0.2999],
         [-0.0382, -2.4873,  1.4180,  ..., -0.9523,  2.4530,  0.5455],
         [-0.0106, -3.3366,  0.2190,  ..., -0.8366, -0.8748,  2.0911],
         [-0.2937,  0.2339, -1.4366,  ...,  3.7317,  0.4581,  0.4576]],

        [[ 1.3456, -1.4687, -0.6613,  ..., -0.5143,  0.0929,  0.1360],
         [-1.0671, -3.3487,  0.2909,  ..., -0.1597,  1.1048,  0.7299],
         [-0.1629, -2.8931,  0.1809,  ...,  1.1297,  1.8244,  1.5784],
         [-0.5000,  1.1761, -1.5980,  ...,  1.9388,  0.8414,  1.1329]],

        [[-1.2082, -1.1292, -1.0032,  ...,  1.1581, -0.3184, -0.2483],
         [-1.4807, -3.2555,  0.0755,  ..., -1.5042,  1.2372,  1.6566],
         [ 1.1231, -2.8787, -1.2553,  ..., -1.4814, -0.2170,  1.0769],
         [ 0.3511,  0.5260, -1.3610,  ...,  0.4302, -1.1225,  0.8980]],

        ...,

        [[ 0.2543,  0.8945, 