Skip to content

Commit

Permalink
fix: don't include generic types in variants
Browse files Browse the repository at this point in the history
  • Loading branch information
alecthomas committed Jan 11, 2024
1 parent 803c965 commit 0697847
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 1 deletion.
3 changes: 3 additions & 0 deletions bin/hermit.hcl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
env = {
"PATH": "${HERMIT_ENV}/scripts:${PATH}",
}
4 changes: 3 additions & 1 deletion cmd/go-check-sumtype/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"flag"
"log"
"os"
"strings"
Expand All @@ -11,7 +12,8 @@ import (

func main() {
log.SetFlags(0)
if len(os.Args) < 2 {
flag.Parse()
if len(flag.Args()) < 1 {
log.Fatalf("Usage: sumtype <packages>\n")
}
args := os.Args[1:]
Expand Down
1 change: 1 addition & 0 deletions decl.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func findSumTypeDecls(pkgs []*packages.Package) ([]sumTypeDecl, error) {
}
pos = pkg.Fset.Position(tspec.Pos())
decl := sumTypeDecl{Package: pkg, TypeName: tspec.Name.Name, Pos: pos}
debugf("found sum type decl: %s.%s", decl.Package.PkgPath, decl.TypeName)
decls = append(decls, decl)
break
}
Expand Down
16 changes: 16 additions & 0 deletions def.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
package gochecksumtype

import (
"flag"
"fmt"
"go/token"
"go/types"
"log"
)

var debug = flag.Bool("debug", false, "enable debug logging")

func debugf(format string, args ...interface{}) {
if *debug {
log.Printf(format, args...)
}
}

// Error as returned by Run()
type Error interface {
error
Expand Down Expand Up @@ -107,6 +117,7 @@ func newSumTypeDef(pkg *types.Package, decl sumTypeDecl) (*sumTypeDef, error) {
Decl: decl,
Ty: iface,
}
debugf("searching for variants of %s.%s\n", pkg.Path(), decl.TypeName)
for _, name := range pkg.Scope().Names() {
obj, ok := pkg.Scope().Lookup(name).(*types.TypeName)
if !ok {
Expand All @@ -116,7 +127,12 @@ func newSumTypeDef(pkg *types.Package, decl sumTypeDecl) (*sumTypeDef, error) {
if types.Identical(ty.Underlying(), iface) {
continue
}
// Skip generic types.
if named, ok := ty.(*types.Named); ok && named.TypeParams() != nil {
continue
}
if types.Implements(ty, iface) || types.Implements(types.NewPointer(ty), iface) {
debugf(" found variant: %s.%s\n", pkg.Path(), obj.Name())
def.Variants = append(def.Variants, obj)
}
}
Expand Down
7 changes: 7 additions & 0 deletions scripts/go-check-sumtype
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash
set -euo pipefail
basedir="$(dirname "$0")/.."
name="$(basename "$0")"
dest="${basedir}/build/devel"
mkdir -p "$dest"
(cd "${basedir}" && ./bin/go build -ldflags="-s -w -buildid=" -o "$dest/${name}" "./cmd/${name}") && exec "$dest/${name}" "$@"
23 changes: 23 additions & 0 deletions testdata/sum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package testdata

//sumtype:decl
type Sum interface{ sum() }

type A struct{}

func (A) sum() {}

type B struct{}

func (B) sum() {}

type C[T any] struct{}

func (C[T]) sum() {}

func SumSwitch(x Sum) {
switch x.(type) {
case A:
case B:
}
}

0 comments on commit 0697847

Please sign in to comment.